Why does softmax get a small gradient when the values are large, as described in the paper 'Attention Is All You Need'?

This is a screenshot of the relevant passage in the original paper: [screenshot omitted]. I understand the paper to be saying that when the dot-product values are large, the gradient of the softmax becomes very small.
However, I tried to calculate the gradient of softmax combined with the cross-entropy loss, and found that the gradient is not directly related to the magnitude of the values passed to softmax.
Even if a single value is large, the gradient can still be large when the other values are also large. (Sorry, I don't know how to post the calculation process here.)
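
(For reference, and presumably the calculation referred to above: with logits z, one-hot target y, and p = softmax(z), the cross-entropy gradient works out to ∂L/∂z_i = p_i − y_i, so it depends on the logits only through the softmax output p rather than directly on their raw magnitude.)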

Ningpo answered 27/2, 2019 at 12:42 Comment(0)

Actually, for a one-hot target, the gradient of cross-entropy with softmax is just grad(-log(softmax(x))), which has magnitude 1 - softmax(x) at the index of the correct class (https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/). If the value passed to the softmax at that index is large, the softmax output saturates at 1 and the gradient therefore goes to 0.
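
For reference, a quick numerical check (a minimal PyTorch sketch, not part of the original answer): the gradient of the cross-entropy loss with respect to the logits is sizeable for moderate logits, but essentially vanishes once the correct class's logit dominates.

import torch
import torch.nn.functional as F

def ce_grad(logits, target=0):
    """Gradient of cross-entropy w.r.t. the logits (single example, class `target`)."""
    z = torch.tensor([logits], requires_grad=True)        # shape (1, C)
    F.cross_entropy(z, torch.tensor([target])).backward()
    return z.grad[0]

# Moderate logits: softmax is far from one-hot, gradient is sizeable.
print(ce_grad([2.0, 1.0, 0.5]))       # ~[-0.37, 0.23, 0.14]
# Large logits, correct class dominates: softmax ~ 1, gradient essentially zero.
print(ce_grad([200.0, 100.0, 50.0]))  # ~[0., 0., 0.]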

Molina answered 27/2, 2019 at 12:55 Comment(3)
Hi, I may not have expressed this clearly in the question. My point is that the value X passed to softmax is a vector like [x1, ..., xi, ..., xn], so it doesn't matter if a single value xi is large: as long as all the xi in X have the same magnitude, the result of softmax will not be equal to 1. Am I right? – Ningpo
Yes, but even a minor difference will dominate the softmax if you blow the values up. E.g. consider logits 0.3 and 0.4: the softmax is far from one. But multiply both numbers by 100, to 30 and 40, and the softmax is essentially 1, even though the relative difference is the same (see the short numerical check after these comments). – Molina
I don't think this should be explained with a one-hot example. In the paper, the softmax output is the attention weights, and the follow-up operation is multiplication by a value matrix, not cross-entropy with a one-hot vector. – Johansen
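
For reference, a minimal sketch (not part of either comment) checking both points above: scaling the logits by 100 pushes the softmax output to an essentially one-hot vector, and the gradient of an attention-style output softmax(x) · v with respect to x collapses accordingly. The value vector v here is an arbitrary illustrative choice.

import torch

def attention_grad(logit_values, v):
    """Gradient of the attention-style output softmax(x) @ v w.r.t. the logits x."""
    x = torch.tensor(logit_values, requires_grad=True)
    (torch.softmax(x, dim=0) @ v).backward()
    return x.grad

v = torch.tensor([1.0, 2.0])                         # arbitrary "values" to attend over

print(torch.softmax(torch.tensor([0.3, 0.4]), 0))    # ~[0.475, 0.525]   -- far from one-hot
print(torch.softmax(torch.tensor([30., 40.]), 0))    # ~[4.5e-05, 1.0]   -- essentially one-hot

print(attention_grad([0.3, 0.4], v))   # ~[-0.25, 0.25]          -- healthy gradient
print(attention_grad([30., 40.], v))   # ~[-4.5e-05, 4.5e-05]    -- vanishing gradient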

A little late to the game and just starting my career in NLP, but I think the core concept driving vanishing gradients with large inputs to the softmax function comes back to the definition of softmax:

softmax(x; X) = exp(x) / Σ_{x'∈X} exp(x'),  where X is the set of logits and x ∈ X

The core intuition is that, assuming the logits lie in a range like [200, 1000], the absolute gaps between them are huge, so the exponential of the largest logit dwarfs the others and softmax computations like the following end up approaching 1, which causes the gradient to approach 0:

exp(500)/(exp(300) + exp(400) + exp(500)) ≈ 1

To further illustrate how the problem can occur in less extreme cases, see how close this softmax computation comes to 1:

exp(500)/(exp(490) + exp(495) + exp(500)) ≈ 0.993262

Now take 500, 495, and 490 and scale them down by a simple factor of 10^2:

exp(5)/(exp(4.9) + exp(4.95) + exp(5)) ≈ 0.350131861449

Clearly, this last one gives us a much fairer conversion to a probability distribution. To me, this looks like a limitation of the softmax expression: feeding it large numerical values produces less usable results.
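
For reference, these numbers can be reproduced with a short sketch (not part of the original answer); scipy.special.softmax subtracts the maximum internally, so exp(500) does not overflow:

import numpy as np
from scipy.special import softmax

# Large, widely spaced logits: the largest term dominates completely.
print(softmax(np.array([300., 400., 500.])))   # ~[0., 0., 1.]
# Still large, closer together: softmax is still nearly one-hot.
print(softmax(np.array([490., 495., 500.])))   # ~[4.5e-05, 6.7e-03, 0.993]
# Scaled down by 10^2: a much more usable distribution.
print(softmax(np.array([4.9, 4.95, 5.0])))     # ~[0.317, 0.333, 0.350]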

I hope this explanation helps clarify what is going on here.

Wrongful answered 10/11, 2023 at 2:2 Comment(0)

Here's an illustration of the cross-entropy gradient for the two-class case: a quiver plot of the gradient (arrows) of the cross-entropy loss w.r.t. the logits.

import torch
import torch.nn.functional as F
import numpy as np
import plotly.figure_factory as ff

# Build a grid of 2-D logit vectors (z1, z2) over which to evaluate the gradient
z1_vals = np.linspace(-4, 4, 40)
z2_vals = np.linspace(-4, 4, 40)
mesh_z1, mesh_z2 = np.meshgrid(z1_vals, z2_vals)

num_classes = 2
# Constant one-hot target: all probability mass on class 0
const_one_hot_targets = torch.zeros(num_classes)
const_one_hot_targets[0] = 1

# Negative gradient (descent direction) of the CE loss at each grid point
mesh_u_grad = np.zeros_like(mesh_z1)
mesh_v_grad = np.zeros_like(mesh_z2)
for i in range(mesh_z1.shape[0]):
    for j in range(mesh_z2.shape[1]):
        z = np.array([mesh_z1[i, j], mesh_z2[i, j]])
        logits = torch.Tensor(z)
        logits.requires_grad = True
        # A float target with the same shape as the input is treated as class probabilities
        ce_loss = F.cross_entropy(logits, const_one_hot_targets)
        ce_loss.backward()
        assert logits.grad is not None
        assert logits.grad.shape == logits.shape
        grad = logits.grad.detach().numpy()
        # Store the negative gradient so the arrows point in the descent direction
        mesh_u_grad[i, j] = -grad[0]
        mesh_v_grad[i, j] = -grad[1]

# Quiver plot of the negative-gradient field
u_plot = mesh_u_grad.flatten()
v_plot = mesh_v_grad.flatten()
uv_norm = (u_plot ** 2 + v_plot ** 2) ** 0.5  # arrow magnitudes, used to normalize arrow length
fig = ff.create_quiver(
    mesh_z1.flatten(),
    mesh_z2.flatten(),    
    u_plot,
    v_plot,
    scale=0.25 / uv_norm.max(),
    arrow_scale=0.3,
    name='Negative gradient',
    line_color='blue',
)
# Update layout
fig.update_layout(
    title='Negative Gradient (Descent Direction) of CE Loss w.r.t. Logits (K=2)',
    xaxis_title='Logit z1',
    yaxis_title='Logit z2',
    showlegend=True,
)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)

fig.show()
Pepsinogen answered 30/6, 2024 at 23:55 Comment(0)
