[Screenshot of the relevant passage from the original paper.] I understand the paper to be saying that when the value of the dot product is large, the gradient of the softmax gets very small.
However, when I tried to work out the gradient of the softmax combined with the cross-entropy loss, I found that the gradient does not depend directly on the absolute values passed to the softmax.
Even if a single value is large, it can still get a large gradient when the other values are also large, since only the differences between the values matter. (Sorry, I don't know how to typeset the calculation here.)
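For example, here is a small numerical check (a sketch using PyTorch; the logit values and the helper name are just illustrative): a logit of 100 gets an almost-zero gradient when the other logits are small, but a sizable one when the other logits are comparably large.
import torch
import torch.nn.functional as F

def grad_wrt_logits(logit_values):
    # Gradient of the cross-entropy loss w.r.t. the logits, with class 0 as the target
    z = torch.tensor([logit_values], requires_grad=True)
    F.cross_entropy(z, torch.tensor([0])).backward()
    return z.grad[0]

print(grad_wrt_logits([100.0, 1.0, 2.0]))    # ~[0, 0, 0]: the softmax saturates
print(grad_wrt_logits([100.0, 99.0, 98.0]))  # ~[-0.33, 0.24, 0.09]: only the gaps matter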
Actually, the gradient of the cross-entropy loss with softmax on a one-hot encoded target is simple: at the index of the correct class, the gradient of -log(softmax(z)) with respect to that logit is softmax(z) - 1, i.e. its magnitude is 1 - softmax(z) (https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/). If the correct-class logit is much larger than the others, the softmax output approaches 1 and the gradient therefore approaches 0.
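As a quick sanity check (a minimal PyTorch sketch; the logits are arbitrary illustrative values), the autograd gradient of the cross-entropy loss matches softmax(z) - y, and it is nearly zero once the correct-class logit dominates:
import torch
import torch.nn.functional as F

# Arbitrary illustrative logits; class 0 is the target class
z = torch.tensor([[8.0, 2.0, 1.0]], requires_grad=True)
target = torch.tensor([0])
F.cross_entropy(z, target).backward()

p = F.softmax(z.detach(), dim=-1)
y = F.one_hot(target, num_classes=3).float()
print(z.grad)  # autograd gradient of the loss w.r.t. the logits
print(p - y)   # analytic gradient softmax(z) - y: same values, ~0 as p[0] -> 1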
A little late to the game and just starting my career in NLP, but I think the core concept driving vanishing gradients with large inputs to the softmax function comes back to the definition of softmax:
softmax(x, X) = exp(x) / (Σ_{x' ∈ X} exp(x')), where X is some domain containing x
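In code, that definition looks roughly like this (a minimal NumPy sketch; subtracting the maximum is only there to keep exp from overflowing and does not change the result):
import numpy as np

def softmax(x):
    # exp(x_i) / sum_j exp(x_j), as in the definition above
    e = np.exp(x - np.max(x))
    return e / e.sum()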
The core intuition is that, for values in a range like x ∈ [200, 1000], exp(x + 1) is so much larger than exp(x) that softmax computations like the following end up approaching 1, which causes the gradient to approach 0:
exp(500)/(exp(300) + exp(400) + exp(500)) ≈ 1
To further illustrate how the problem can occur in less extreme cases, see how close this softmax computation comes to 1:
exp(500)/(exp(490) + exp(495) + exp(500)) ≈ 0.993262
Now take 500, 495, and 490 and scale them down by a simple factor of 10^2:
exp(5)/(exp(4.9) + exp(4.95) + exp(5)) ≈ 0.350131861449
Clearly, this last one gives us a fairer conversion to a probability distribution. To me, it appears to be a limitation of the softmax expression that feeding it large values produces saturated, less usable results.
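For what it's worth, the three numbers above are easy to reproduce (a sketch using scipy.special.softmax, which evaluates the same expression in a numerically stable way):
import numpy as np
from scipy.special import softmax  # numerically stable softmax

print(softmax(np.array([300.0, 400.0, 500.0]))[-1])  # ~1.0
print(softmax(np.array([490.0, 495.0, 500.0]))[-1])  # ~0.993262
print(softmax(np.array([4.9, 4.95, 5.0]))[-1])       # ~0.350132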
I hope this explanation helps clarify what is going on here.
Here's an illustration of the cross-entropy gradient for a two-class case: a quiver plot of the gradient (arrows) of the cross-entropy loss w.r.t. the logits, generated by the code below.
import torch
import torch.nn.functional as F
import numpy as np
import plotly.figure_factory as ff
# Grid of logit pairs (z1, z2) over which to evaluate the gradient
z1_vals = np.linspace(-4, 4, 40)
z2_vals = np.linspace(-4, 4, 40)
mesh_z1, mesh_z2 = np.meshgrid(z1_vals, z2_vals)
num_classes = 2
const_one_hot_targets = torch.zeros(num_classes)
const_one_hot_targets[0] = 1
mesh_u_grad = np.zeros_like(mesh_z1)
mesh_v_grad = np.zeros_like(mesh_z2)
for i in range(mesh_z1.shape[0]):
    for j in range(mesh_z2.shape[1]):
        z = np.array([mesh_z1[i, j], mesh_z2[i, j]])
        logits = torch.Tensor(z)
        logits.requires_grad = True
        ce_loss = F.cross_entropy(logits, const_one_hot_targets)
        ce_loss.backward()
        assert logits.grad is not None
        assert logits.grad.shape == logits.shape
        grad = logits.grad.detach().numpy()
        # Store the negative gradient (the descent direction) for plotting
        mesh_u_grad[i, j] = -grad[0]
        mesh_v_grad[i, j] = -grad[1]
        logits.grad.zero_()
# Create quiver plot for cross-entropy gradients
u_plot = mesh_u_grad.flatten()
v_plot = mesh_v_grad.flatten()
uv_norm = (u_plot ** 2 + v_plot ** 2) ** 0.5
fig = ff.create_quiver(
    mesh_z1.flatten(),
    mesh_z2.flatten(),
    u_plot,
    v_plot,
    scale=0.25 / uv_norm.max(),
    arrow_scale=0.3,
    name='Gradient',
    line_color='blue',
)
# Update layout
fig.update_layout(
    title='Gradient Vectors of CE Loss with Respect to Logits (K=2)',
    xaxis_title='Logit z1',
    yaxis_title='Logit z2',
    showlegend=True,
)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)
fig.show()
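In the resulting plot, the arrows (drawn as the negative gradient, i.e. the descent direction) shrink toward zero wherever z1 is much larger than z2, that is, wherever the softmax already assigns probability close to 1 to the correct class. That is exactly the vanishing-gradient regime discussed above.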