I tried to find documentation but could not find anything about torch.softmax.
What is the difference among torch.nn.Softmax, torch.nn.functional.softmax, torch.softmax, and torch.nn.functional.log_softmax?
Examples are appreciated.
import torch

x = torch.rand(5)
x1 = torch.nn.Softmax(dim=0)(x)                 # module: instantiate, then call
x2 = torch.nn.functional.softmax(x, dim=0)      # plain function
x3 = torch.nn.functional.log_softmax(x, dim=0)  # log(softmax(x)) in one stable step

print(x1)
print(x2)
print(torch.log(x1))
print(x3)
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
torch.nn.Softmax and torch.nn.functional.softmax give identical outputs; one is a class (a PyTorch module), the other is a function.
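As a minimal sketch of that distinction (the variable names here are illustrative), the module form is instantiated once and then called, while the functional form is called directly:

import torch
import torch.nn.functional as F

x = torch.rand(5)

softmax_module = torch.nn.Softmax(dim=0)  # class: build an instance first
out_module = softmax_module(x)            # then call it like a function
out_functional = F.softmax(x, dim=0)      # plain function: call it directly

print(torch.allclose(out_module, out_functional))  # True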
log_softmax applies log after applying softmax, i.e. log_softmax(x) = log(softmax(x)), but computes both in a single, numerically stable step.
NLLLoss takes log-probabilities (log(softmax(x))) as input, so you need log_softmax for NLLLoss; it is numerically more stable and usually yields better results.
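A small illustration of that pairing (the shapes and target values here are made up): feeding log_softmax output into nll_loss matches cross_entropy on the raw logits, because CrossEntropyLoss combines LogSoftmax and NLLLoss:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)         # made-up batch of 3 samples, 5 classes
targets = torch.tensor([1, 0, 4])  # made-up class indices

log_probs = F.log_softmax(logits, dim=1)
loss_nll = F.nll_loss(log_probs, targets)   # NLLLoss expects log-probabilities
loss_ce = F.cross_entropy(logits, targets)  # cross_entropy expects raw logits

print(torch.allclose(loss_nll, loss_ce))    # True: cross_entropy = log_softmax + nll_loss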
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.LazyLinear(128)
        self.activation = nn.ReLU()
        self.layer_2 = nn.Linear(128, 10)
        self.output_function = nn.Softmax(dim=1)

    def forward(self, x, softmax="module"):
        y = self.layer_1(x)
        y = self.activation(y)
        y = self.layer_2(y)
        if softmax == "module":
            return self.output_function(y)
        # OR
        if softmax == "torch":
            return torch.softmax(y, dim=1)
        # OR (equivalent functional form)
        if softmax == "functional":
            return nn.functional.softmax(y, dim=1)
        # OR (careful: the log is there to ensure numerical
        # stability, so use torch.exp wisely)
        if softmax == "log":
            return torch.exp(torch.log_softmax(y, dim=1))
        raise ValueError(f"Unknown softmax type {softmax}")

x = torch.rand(2, 2)
net = Network()
for s in ["module", "torch", "functional", "log"]:
    print(net(x, softmax=s))
Basically, nn.Softmax() creates a module instance (a callable object), whereas the others are plain functions.
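One practical consequence of being a module: nn.Softmax can be registered inside containers such as nn.Sequential, whereas the functional form has to be called explicitly in forward code. A minimal sketch (the layer sizes are made up):

import torch
import torch.nn as nn

model = nn.Sequential(  # the module slots in alongside other layers
    nn.Linear(4, 10),
    nn.Softmax(dim=1),
)

x = torch.rand(2, 4)
print(model(x).sum(dim=1))  # each row sums to 1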
Why would you need a log softmax? Well, an example lies in the docs of nn.Softmax:

This module doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use LogSoftmax instead (it's faster and has better numerical properties).
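To see those "better numerical properties" concretely, here is a quick sketch (the extreme logit value is made up): taking log(softmax(x)) underflows for large logits, while log_softmax stays finite:

import torch

x = torch.tensor([1000.0, 0.0])  # made-up extreme logit

print(torch.log(torch.softmax(x, dim=0)))  # tensor([0., -inf]): softmax underflows to 0
print(torch.log_softmax(x, dim=0))         # tensor([0., -1000.]): computed stably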
See also What is the difference between log_softmax and softmax?