Disable grad and backward() globally?
How do I globally disable grad, backward(), and any other non-forward() functionality in PyTorch?

I see examples of how to do it locally, but not globally.

The docs say that what I may be looking for is inference-only mode, but how do I set that globally?

Reata answered 1/9, 2021 at 3:06
You can use torch.set_grad_enabled(False) to disable gradient tracking globally for the current thread (grad mode is thread-local, so a newly spawned thread starts with gradients enabled again). Once it has been called, any attempt to call backward() raises an exception, because no computation graph is recorded.

import numpy as np
import torch

a = torch.tensor(np.random.rand(64, 5), dtype=torch.float32)
l = torch.nn.Linear(5, 10)

o = torch.sum(l(a))
print(o.requires_grad)  # True
o.backward()
print(l.weight.grad)    # prints the accumulated gradients

torch.set_grad_enabled(False)  # disable grad tracking for this thread

o = torch.sum(l(a))
print(o.requires_grad)  # False
o.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
print(l.weight.grad)
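
Since the question also mentions inference-only mode: torch.inference_mode (available since PyTorch 1.9) disables autograd more aggressively than plain no-grad mode, but it is exposed as a context manager/decorator rather than a global flag. A minimal sketch (the predict function is just a hypothetical example, not part of the original code):

import torch

l = torch.nn.Linear(5, 10)

# As a context manager: tensors created inside are inference tensors
# and can never be used in autograd later.
with torch.inference_mode():
    o = torch.sum(l(torch.rand(64, 5)))
    print(o.requires_grad)  # False

# As a decorator, to keep a whole entry point gradient-free
# (hypothetical helper for illustration):
@torch.inference_mode()
def predict(model, x):
    return model(x)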
Lw answered 1/9, 2021 at 7:36
