What is the difference between torch.nn.Softmax, torch.nn.functional.softmax, torch.softmax and torch.nn.functional.log_softmax?

I tried to find documentation but could not find anything about torch.softmax.

What is the difference between torch.nn.Softmax, torch.nn.functional.softmax, torch.softmax and torch.nn.functional.log_softmax?

Examples are appreciated.

Apfel answered 17/9, 2021 at 3:8
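```python
import torch

x = torch.rand(5)

# Pass dim explicitly; omitting it triggers an "implicit dimension" warning.
x1 = torch.nn.Softmax(dim=0)(x)                 # module: instantiate, then call
x2 = torch.nn.functional.softmax(x, dim=0)      # plain function
x3 = torch.nn.functional.log_softmax(x, dim=0)  # log of the softmax

print(x1)
print(x2)
print(torch.log(x1))
print(x3)
```

Output (one run; the values change each time because x is random):

```
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
```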

torch.nn.Softmax and torch.nn.functional.softmax give identical outputs: one is a class (a PyTorch module), the other is a function. log_softmax is mathematically the same as applying log after softmax, i.e. log(softmax(x)), but it is computed in a single step.
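A minimal sketch to check this on a 2-D input, assuming a batch of logits where the class dimension is dim=1 (the shape and the seed are arbitrary choices for illustration). It also shows that torch.softmax gives the same result as the other two:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)          # arbitrary seed, only for reproducibility
logits = torch.randn(3, 4)    # hypothetical batch: 3 samples, 4 classes

softmax_module = nn.Softmax(dim=1)        # a module: build once, then call it
p_module = softmax_module(logits)
p_functional = F.softmax(logits, dim=1)   # plain function
p_torch = torch.softmax(logits, dim=1)    # same computation, exposed on torch

# All three compute the same probabilities.
assert torch.allclose(p_module, p_functional)
assert torch.allclose(p_module, p_torch)

# Each row (one sample) sums to 1 along the chosen dim.
print(p_module.sum(dim=1))

# For well-behaved inputs, log_softmax matches log(softmax).
assert torch.allclose(F.log_softmax(logits, dim=1), torch.log(p_functional), atol=1e-6)
```

The dim argument matters: with dim=1 each row is normalized, while dim=0 would normalize each column instead.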

NLLLoss expects log-probabilities (log(softmax(x))) as input, so you would put log_softmax in front of NLLLoss. log_softmax is also numerically more stable than computing torch.log(softmax(x)) yourself, which usually yields better results.
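To illustrate the stability point, here is a small sketch with hypothetical logits that have a large gap between classes. The naive route underflows to 0 inside softmax and then takes log(0), while log_softmax stays finite:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits with a large gap between the two classes.
logits = torch.tensor([0.0, 1000.0])

# Naive route: exp(-1000) underflows to exactly 0 in float32,
# so the first probability is 0 and its log is -inf.
naive = torch.log(F.softmax(logits, dim=0))

# log_softmax works in log space (x - logsumexp(x)) and stays finite.
stable = F.log_softmax(logits, dim=0)

print(naive)
print(stable)
```

A -inf like this turns into inf or nan as soon as it reaches a loss or a gradient, which is why log_softmax is preferred over log(softmax).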

Ecg answered 17/9, 2021 at 3:28

```python
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.LazyLinear(128)  # in_features is inferred on the first call
        self.activation = nn.ReLU()
        self.layer_2 = nn.Linear(128, 10)
        self.output_function = nn.Softmax(dim=1)

    def forward(self, x, softmax="module"):
        y = self.layer_1(x)
        y = self.activation(y)
        y = self.layer_2(y)
        if softmax == "module":
            return self.output_function(y)

        # OR
        if softmax == "torch":
            return torch.softmax(y, dim=1)

        # OR (the functional API is not deprecated; only calling it without dim is)
        if softmax == "functional":
            return nn.functional.softmax(y, dim=1)

        # OR (careful: exp(log_softmax(y)) is just softmax(y). The point of
        # log_softmax is to stay in log space for numerical stability, so
        # exponentiating it again throws that benefit away.)
        if softmax == "log":
            return torch.exp(torch.log_softmax(y, dim=1))

        raise ValueError(f"Unknown softmax type {softmax}")


x = torch.rand(2, 2)
net = Network()

# All four variants print the same probabilities (up to floating-point rounding).
for s in ["module", "torch", "functional", "log"]:
    print(net(x, softmax=s))
```

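Basically, nn.Softmax() creates a module: an object you construct once (typically in __init__) and then call like a function. torch.softmax and torch.nn.functional.softmax are plain functions that compute the same result directly.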
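One practical consequence, sketched below under arbitrary layer sizes: because nn.Softmax is a module, it can be dropped into containers such as nn.Sequential like any other layer, while the functional forms are called explicitly, e.g. inside forward(). The module itself holds no learnable parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Softmax is a module, so it can sit inside nn.Sequential like a layer.
# The sizes (8 inputs, 3 classes) are arbitrary, for illustration only.
model = nn.Sequential(
    nn.Linear(8, 3),
    nn.Softmax(dim=1),
)

x = torch.randn(5, 8)  # hypothetical batch of 5 samples
probs = model(x)

# The same result using the functional form on the linear layer's output.
probs_again = torch.softmax(model[0](x), dim=1)
assert torch.allclose(probs, probs_again)

# The Softmax module is only a thin wrapper: it has no parameters.
assert list(nn.Softmax(dim=1).parameters()) == []

print(probs.shape)  # torch.Size([5, 3])
```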

Why would you need log_softmax? One example is in the docs of nn.Softmax:

This module doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use LogSoftmax instead (it's faster and has better numerical properties).
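A minimal sketch of that pairing, using hypothetical logits and targets. It also compares against nn.CrossEntropyLoss, which the PyTorch docs describe as equivalent to LogSoftmax followed by NLLLoss and which takes raw logits directly. Feeding plain softmax probabilities to NLLLoss raises no error but computes a different (wrong) quantity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 5)              # hypothetical: 4 samples, 5 classes
targets = torch.tensor([1, 0, 4, 2])    # hypothetical class indices

# NLLLoss expects log-probabilities, so put LogSoftmax in front of it.
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

# CrossEntropyLoss takes raw logits and does LogSoftmax + NLLLoss internally.
ce = nn.CrossEntropyLoss()(logits, targets)
assert torch.allclose(nll, ce)

# Mistake: plain probabilities into NLLLoss. No error is raised, but the
# result is just minus the mean target probability (a value in (-1, 0)).
wrong = nn.NLLLoss()(torch.softmax(logits, dim=1), targets)

print(nll.item(), ce.item(), wrong.item())
```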

See also the question "What is the difference between log_softmax and softmax?"

Inhaler answered 17/9, 2021 at 3:19

© 2022 - 2024 — McMap. All rights reserved.