Logsoftmax stability

I know how to make softmax numerically stable by subtracting max_i x_i from each element. This avoids overflow and underflow. The remaining problem is that softmax(x) itself can still underflow to zero, so taking its log yields -infinity.

I am not sure how to fix this. I know it is a common problem, and I have read several answers about it, but I didn't understand them, so I am still confused about how to solve it.
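
For concreteness, here is a minimal sketch of what I mean (the input values are just made up to trigger the underflow):

import numpy as np

x = np.array([1.0, 2.0, 800.0])
# stable softmax: shift by the max so exp() cannot overflow
s = np.exp(x - np.max(x)) / np.sum(np.exp(x - np.max(x)))
print(s)          # [0. 0. 1.]  -- the first two entries underflow to exactly 0
print(np.log(s))  # [-inf -inf  0.]  -- taking the log then gives -infinity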

PS: If you provide a simple example, it would be awesome.

Raptorial answered 20/5, 2017 at 1:27

To stabilize log-softmax, most implementations, such as TensorFlow and Theano, use a trick that takes out the largest component, max(x_i). The same trick is often used to compute softmax stably. For log-softmax, we begin with:

log softmax(x_i) = log( exp(x_i) / sum_j exp(x_j) )
                 = log( exp(x_i - b) * exp(b) / (exp(b) * sum_j exp(x_j - b)) )

After cancelling the exp(b) factors and using the fact that log(exp(x)) = x, we have:

log softmax(x_i) = (x_i - b) - log( sum_j exp(x_j - b) )

If we set b = max(x_i), this new expression is stable against both overflow and underflow: every exponent x_j - b is at most 0, so exp cannot overflow, and the sum contains the term exp(0) = 1, so its log can never be -infinity.


In terms of code, if x is a vector:

import numpy as np

def log_softmax(x):
    # shift by the max so the largest exponent is 0 and exp() cannot overflow
    x_off = x - np.max(x)
    # the sum is at least exp(0) = 1, so the log is finite (never -inf)
    return x_off - np.log(np.sum(np.exp(x_off)))
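
For example, calling the function above on made-up inputs that break the naive computation:

x = np.array([1.0, 2.0, 800.0])
print(log_softmax(x))  # [-799. -798.    0.]  -- all values finite

# naive version for comparison: exp(800) overflows to inf
print(np.log(np.exp(x) / np.sum(np.exp(x))))  # [-inf -inf  nan]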

See also: https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/

Octagon answered 1/9, 2018 at 20:42
TensorFlow defines log-softmax as:

logsoftmax = logits - log(reduce_sum(exp(logits), dim))

Reference: https://www.tensorflow.org/api_docs/python/tf/nn/log_softmax
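
That formula is the mathematical definition; the op itself is implemented in a numerically stable way, so it does not return -inf. A quick check with made-up logits:

import tensorflow as tf

logits = tf.constant([[4., 5., 1000.]])
print(tf.nn.log_softmax(logits))
# tf.Tensor([[-996. -995.    0.]], shape=(1, 3), dtype=float32)
# finite results, even though exp(1000) would overflow if computed naively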

Mailman answered 12/12, 2017 at 8:41

Just use tf.nn.softmax_cross_entropy_with_logits, as it takes care of the NaN problem:

import tensorflow as tf

# Signature: tf.nn.softmax_cross_entropy_with_logits(labels, logits, axis=-1, name=None)

logits = tf.constant([[4, 5, 1000]], dtype=tf.float32)
labels = tf.constant([[1, 0, 1]], dtype=tf.float32)

# Case 1: the fused op computes log-softmax stably,
# so the result stays finite even with an extreme logit like 1000
output = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(output)
>>> tf.Tensor([996.], shape=(1,), dtype=float32)

# Case 2: computing softmax and log separately produces NaN
a = tf.nn.softmax(logits)
output = tf.reduce_sum(-(labels * tf.math.log(a)))
print(output)
>>> tf.Tensor(nan, shape=(), dtype=float32)

# This happens because the small softmax values underflow to zero:
# log(0) is -inf, and 0 * -inf yields NaN in the sum

print(a)
>>> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0., 0., 1.]], dtype=float32)>
Almund answered 22/5, 2021 at 8:45

Mathematical tricks cannot make log(0) be anything other than -inf. If you think it through, the only way is to normalize the data so that you don't end up there.

Butterfish answered 8/7, 2019 at 23:51
