How does TensorFlow handle non-differentiable nodes during gradient calculation?

I understand the concept of automatic differentiation, but I couldn't find any explanation of how TensorFlow calculates the error gradient for non-differentiable functions such as tf.where in my loss function or tf.cond in my graph. It works just fine, but I would like to understand how TensorFlow backpropagates the error through such nodes, since there is no formula for their gradient.

Lilt answered 8/11, 2018 at 13:3

In the case of tf.where, you have a function with three inputs (condition C, value on true T and value on false F) and one output Out. The gradient function receives one value and has to return three values. Currently, no gradient is computed for the condition (that would hardly make sense), so you only need the gradients for T and F. Assuming the inputs and the output are vectors, imagine C[0] is True. Then Out[0] comes from T[0], and its gradient should propagate back. On the other hand, F[0] was discarded, so its gradient should be zero. If C[1] were False, then the gradient for F[1] should propagate, but not the one for T[1]. In short, for T you propagate the incoming gradient where C is True and make it zero where C is False, and the opposite for F. If you look at the implementation of the gradient of tf.where (the Select operation), it does exactly that:

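# Excerpt from TensorFlow's own gradient definitions (math_grad.py). TensorFlow
# already registers this gradient, so it is shown for reading, not re-registering.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]  # condition C
  x = op.inputs[1]  # value on true T, used only for the shape of the zeros
  zeros = array_ops.zeros_like(x)
  # None for C; T gets grad where C is True, F gets grad where C is False.
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))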
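To see this masking rule from the user side, here is a minimal sketch assuming TensorFlow 2.x with eager execution. The tensors c, t, f and the squared-sum loss are made up for illustration. In TF 2.x, tf.where is backed by a SelectV2 op rather than Select, but its gradient applies the same masking.

```python
import tensorflow as tf

# Illustrative inputs: condition C, value on true T, value on false F.
c = tf.constant([True, False, True])
t = tf.Variable([1.0, 2.0, 3.0])
f = tf.Variable([10.0, 20.0, 30.0])

with tf.GradientTape() as tape:
    out = tf.where(c, t, f)          # out = [1, 20, 3]
    loss = tf.reduce_sum(out * out)  # upstream gradient d(loss)/d(out) = 2 * out = [2, 40, 6]

grad_t, grad_f = tape.gradient(loss, [t, f])
# grad_t keeps the upstream gradient where c is True:  [2, 0, 6]
# grad_f keeps the upstream gradient where c is False: [0, 40, 0]
print(grad_t.numpy(), grad_f.numpy())
```

Each element of the upstream gradient reaches exactly one of T or F, depending on which one actually produced that element of Out.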

Note that _SelectGrad does not use the input values themselves; propagating the gradient further back is left to the gradients of the operations that produced those inputs. For tf.cond the code is a bit more complicated, because the same operation (Merge) is used in different contexts, and tf.cond also uses Switch operations internally. The idea, however, is the same. Essentially, a Switch operation is used for each input, so the input that was activated (the first if the condition was True, the second otherwise) receives the incoming gradient, while the other input gets a "switched off" gradient (like None) that does not propagate back any further.
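The same outcome can be observed for tf.cond with the following sketch, assuming TensorFlow 2.x eager execution. The mechanism here differs from the graph-mode Switch/Merge machinery described above: in eager mode tf.cond simply runs the selected branch, so the GradientTape never records the other one, and the gradient for a variable used only in the untaken branch comes back as None. The variables a and b and the helper branch_grads are hypothetical, and behavior inside tf.function graphs is not covered here.

```python
import tensorflow as tf

a = tf.Variable(3.0)  # used only by the true branch
b = tf.Variable(5.0)  # used only by the false branch

def branch_grads(pred):
    """Hypothetical helper: gradients of the tf.cond output w.r.t. a and b."""
    with tf.GradientTape() as tape:
        # Eager mode: only the branch selected by pred is executed and recorded.
        y = tf.cond(pred, lambda: a * a, lambda: 2.0 * b)
    return tape.gradient(y, [a, b])

grad_a, grad_b = branch_grads(tf.constant(True))     # d(a*a)/da = 6; b unused, so None
grad_a2, grad_b2 = branch_grads(tf.constant(False))  # a unused, so None; d(2b)/db = 2
print(grad_a, grad_b, grad_a2, grad_b2)
```

As with tf.where, no gradient is produced for the condition itself; only the branch that contributed to the output passes a gradient back.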

Maryettamaryjane answered 8/11, 2018 at 16:51
Quite obvious how it works, now that I've read your answer! (Lilt)