Tensorflow issue with softmax
I have a Tensorflow multiclass classifier that is generating nan or inf while computing probabilities using tf.nn.softmax. See the following snippet (logits is of shape batch_size x 6, since I have 6 classes and the output is one-hot encoded). batch_size is 1024.

logits = tf.debugging.check_numerics(logits, message='bad logits', name=None)
probabilities = tf.nn.softmax(logits=logits, name='Softmax')
probabilities = tf.debugging.check_numerics(probabilities, message='bad probabilities', name=None)

The classifier fails on the last statement, as it finds nan or inf in probabilities. The logits themselves are clean; otherwise the first statement would have failed.

From what I read about tf.nn.softmax, it can handle very large and very small values in logits. I have verified this in interactive mode.

>>> with tf.Session() as s:
...   a = tf.constant([[1000, 10], [-100, -200], [3, 4.0]])
...   sm = tf.nn.softmax(logits=a, name='Softmax')
...   print(a.eval())
...   print(sm.eval())
...
[[1000.   10.]
 [-100. -200.]
 [   3.    4.]]
[[1.         0.        ]
 [1.         0.        ]
 [0.26894143 0.7310586 ]]

I then tried clipping the values in logits and the whole thing now works. See the modified snippet below.

logits = tf.debugging.check_numerics(logits, message='logits', name=None)
safe_logits = tf.clip_by_value(logits, -15.0, 15.0)
probabilities = tf.nn.softmax(logits=safe_logits, name='Softmax')
probabilities = tf.debugging.check_numerics(probabilities, message='bad probabilities', name=None)

In the second statement, I clip the values in logits to the range [-15, 15], and that somehow prevents nan/inf in the softmax computation. So I was able to fix the issue at hand.

However, I still don't understand why this clipping works. (I should mention that clipping between -20 and 20 does not work; the model still fails with nan or inf in probabilities.)
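
As a sanity check on the 15-vs-20 puzzle: in float32, exp() only overflows once its argument exceeds roughly 88, so neither clip bound should matter to the softmax arithmetic itself. A quick NumPy check:

```python
import numpy as np

# float32 can represent values up to ~3.4e38, so exp() only overflows
# once its argument exceeds roughly 88. Clip bounds of 15 vs 20 are
# both far inside the safe range for the exponentials in softmax.
print(np.exp(np.float32(20.0)))  # ~4.85e8, finite
print(np.exp(np.float32(88.0)))  # ~1.65e38, still finite
print(np.exp(np.float32(89.0)))  # inf: past the float32 ceiling
```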

Could someone help me understand why this is the case?

I am using tensorflow 1.15.0, running on a 64-bit instance.

Lazaro answered 30/8, 2021 at 18:35 Comment(4)
How do you compute logits? – Thomasson
The logits are the outputs of the previous layer (right before the head). – Lazaro
I tried your code with tensorflow 2.0, and I did not get the error you describe. – Lutherlutheran
It is hard to reproduce this error on a small sample. The job runs for 100K steps before this happens. – Lazaro
The first place to look is the values themselves, which you have already checked. The second place to look is the gradients: even if a value appears reasonable, a very steep gradient will eventually blow up both the gradients and the values during backprop.

For example, if the logits are generated by something like log(x), an x of 0.001 will produce -6.9, which looks pretty benign. But the gradient of log(x) is 1/x, which at x = 0.001 is 1000! That will quickly explode the gradients and values during backprop / forward prop.

# Pretend this is the source value that is fed to a function that generates the logit. 
>>> x = tf.Variable(0.001)

# Let's operate on the source value to generate the logit. 
>>> with tf.GradientTape() as tape:
...   y = tf.math.log(x)
... 

# The logit looks okay... -6.9. 
>>> y
<tf.Tensor: shape=(), dtype=float32, numpy=-6.9077554>

# But the gradient is exploding. 
>>> tape.gradient(y,x)
<tf.Tensor: shape=(), dtype=float32, numpy=999.99994>
>>> 

Clipping the logit would appear to work by feeding smaller values to softmax, but that's probably not why it helps. (In fact, softmax can handle a logit with value tf.float32.max no problem, so the value of the logit is very unlikely to be the issue.) What may really be happening is that when you clip to 15, you are also setting the gradient to zero wherever the logit would otherwise be 20 with an explosive gradient. So clipping the value also clips the gradient.

# This is same source variable as above. 
>>> x = tf.Variable(0.001)

# Now let's operate with clipping. 
>>> with tf.GradientTape() as tape:
...   y = tf.clip_by_value(tf.math.log(x), -1., 1.)
... 

# The clipped logit still looks okay... 
>>> y
<tf.Tensor: shape=(), dtype=float32, numpy=-1.0>

# What may be more important is that the clipping has also zeroed out the gradient
>>> tape.gradient(y,x)
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Linnette answered 2/9, 2021 at 19:23 Comment(7)
Your explanation seems plausible, but it still does not explain why the logits are okay (no nans/infs) while the softmax computed from them shows nan/inf. Is there a possibility that with graph computation, all operations, including the gradients of the previous step, are computed at the time of computing softmax? (That still wouldn't explain why the logits are okay, though.) – Lazaro
Does your nan/inf occur during model.fit or during model.predict / model.evaluate / model.__call__? – Linnette
I think it happens during fit. This piece of code is common to both fit and predict, so it is hard to tell exactly where it is failing. But it fails during training, at around 100k steps. Every few minutes, there is a callback to evaluate as well. So I am not sure if it is fit or predict/evaluate. – Lazaro
Yup, that would support the idea I hypothesized. Although you could sometimes print the logits and see them as okay, during backprop, at some point the gradient gets large, like 1000, then explodes into inf/nan. It may not be the logit triggering this; it could be upstream. When this happens to me, I usually try reducing the LR, and sometimes clip the gradients, both of which are cheap ways to solve the problem. Adding batch norms can help. Then I add a checkpoint callback so I can get the model back to a state that's about to explode. – Linnette
And if you really want to go deep into it, you can build a custom training loop using tf.GradientTape and record the gradients directly. keras.io/guides/writing_a_training_loop_from_scratch/… – Linnette
Oh yeah, the first thing I'd do is stick tf.debugging.check_numerics everywhere. This will help you isolate the op that is actually triggering the nan. – Linnette
Thanks. I'll try some of the suggestions you mentioned. – Lazaro
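
To make the gradient-clipping suggestion from the comments concrete, here is a minimal sketch of clipping the gradients rather than the logits, using a TF 2.x-style custom loop like the answer's examples. The model, optimizer, and data here are illustrative placeholders, not the asker's actual setup:

```python
import tensorflow as tf

# Illustrative stand-ins for the real model and data.
model = tf.keras.Sequential([tf.keras.layers.Dense(6)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

x = tf.random.normal([8, 4])
y = tf.one_hot(tf.random.uniform([8], maxval=6, dtype=tf.int32), depth=6)

with tf.GradientTape() as tape:
    logits = model(x)
    loss = loss_fn(y, logits)

grads = tape.gradient(loss, model.trainable_variables)
# Rescale the whole gradient so its global norm is at most 1.0.
# This caps explosive steps without zeroing gradients the way
# clip_by_value on the logits does.
grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

Unlike tf.clip_by_value on the logits, global-norm clipping only shrinks the update's magnitude; it never kills the gradient signal outright.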