I understand the concept of automatic differentiation, but I couldn't find any explanation of how TensorFlow calculates the error gradient for non-differentiable functions, for example tf.where in my loss function or tf.cond in my graph. It works just fine, but I would like to understand how TensorFlow backpropagates the error through such nodes, since there is no formula to calculate the gradient from them.
In the case of tf.where, you have a function with three inputs (condition C, value on true T and value on false F) and one output Out. The gradient function receives one value (the gradient with respect to Out) and has to return three values, one per input. Currently, no gradient is computed for the condition (that would hardly make sense for a boolean), so you only need to compute the gradients for T and F.

Assuming the inputs and the output are vectors, imagine C[0] is True. Then Out[0] comes from T[0], and its gradient should propagate back to T[0]. On the other hand, F[0] was discarded, so its gradient should be zero. If C[1] were False, then the gradient should propagate to F[1] but not to T[1]. So, in short, for T you should propagate the given gradient where C is True and make it zero where it is False, and the opposite for F.

If you look at the implementation of the gradient of tf.where (the Select operation), it does exactly that:
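```python
# Select's gradient as registered in TensorFlow's source (for reading: TensorFlow
# already registers it, so registering it a second time would fail).
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]  # condition C
  x = op.inputs[1]  # T, used only for its shape and dtype
  zeros = array_ops.zeros_like(x)
  # No gradient for C; grad goes to T where C is True and to F where C is False.
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))
```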
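The masking rule can be checked without TensorFlow at all. The following is a minimal NumPy sketch that mirrors the logic of _SelectGrad on small vectors; select_forward and select_backward are hypothetical helper names for illustration, not TensorFlow functions, and upstream stands in for the gradient that later operations would pass back.

```python
import numpy as np


def select_forward(c, t, f):
    """Element-wise select, behaving like tf.where(c, t, f) for equal shapes."""
    return np.where(c, t, f)


def select_backward(c, t, grad):
    """NumPy sketch of the masking rule used by _SelectGrad.

    Returns gradients for (c, t, f): None for the boolean condition,
    the upstream gradient where c is True for t (zero elsewhere), and
    the upstream gradient where c is False for f (zero elsewhere).
    """
    # Only the shape of t is used here; its values are never read.
    zeros = np.zeros_like(t, dtype=grad.dtype)
    return None, np.where(c, grad, zeros), np.where(c, zeros, grad)


c = np.array([True, False, True])
t = np.array([1.0, 2.0, 3.0])
f = np.array([10.0, 20.0, 30.0])

out = select_forward(c, t, f)          # [1.0, 20.0, 3.0]
upstream = np.array([0.5, 0.25, 2.0])  # dLoss/dOut, supplied by later ops
_, grad_t, grad_f = select_backward(c, t, upstream)
# grad_t == [0.5, 0.0, 2.0]  (gradient kept where c is True)
# grad_f == [0.0, 0.25, 0.0] (gradient kept where c is False)
```

Each position of upstream ends up in exactly one of grad_t or grad_f, which is why the select itself needs no derivative formula: it only routes gradients.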
Note that the values of T and F are not used in this computation (x only supplies the shape and dtype for zeros); propagating the gradient further back through those values is done by the gradients of the operations that produced those inputs. For tf.cond, the code is a bit more complicated, because the same operation (Merge) is used in different contexts, and tf.cond also uses Switch operations internally. However, the idea is the same. Essentially, Switch operations are used for each input, so the input that was activated (the first if the condition was True and the second otherwise) gets the received gradient, while the other input gets a "switched off" gradient (like None) and does not propagate back any further.
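To make this routing concrete, here is a toy scalar model in plain Python. switch, merge, merge_grad and cond_forward_backward are hypothetical helpers that only imitate the Switch/Merge behaviour described above; they are not TensorFlow's API, TensorFlow's own handling of "switched off" values is more involved than using None, the false-branch-first ordering is just a choice made for this sketch, and the branch derivatives are passed in by hand instead of being computed automatically.

```python
from typing import Callable, List, Optional, Tuple

DEAD = None  # stand-in for a "switched off" value or gradient (sketch only)


def switch(value: float, pred: bool) -> Tuple[Optional[float], Optional[float]]:
    """Send value to (output_false, output_true); the untaken output is dead."""
    return (DEAD, value) if pred else (value, DEAD)


def merge(inputs: List[Optional[float]]) -> Tuple[float, int]:
    """Forward the single live input and report which index it came from."""
    for i, v in enumerate(inputs):
        if v is not DEAD:
            return v, i
    raise ValueError("merge received no live input")


def merge_grad(grad: float, value_index: int, n: int) -> List[Optional[float]]:
    """Backward of merge: route the gradient only to the input that was taken."""
    return [grad if i == value_index else DEAD for i in range(n)]


def cond_forward_backward(
    pred: bool,
    x: float,
    true_fn: Callable[[float], float],
    true_dfn: Callable[[float], float],
    false_fn: Callable[[float], float],
    false_dfn: Callable[[float], float],
    upstream: float = 1.0,
) -> Tuple[float, float]:
    """Toy cond on a scalar x built from switch/merge, plus its backward pass.

    true_dfn/false_dfn return each branch's local derivative d(branch)/dx.
    Returns (output, upstream * dOut/dx).
    """
    # Forward: only the taken branch receives a live value.
    x_false, x_true = switch(x, pred)
    out_false = DEAD if x_false is DEAD else false_fn(x_false)
    out_true = DEAD if x_true is DEAD else true_fn(x_true)
    out, taken = merge([out_false, out_true])

    # Backward through merge: the untaken branch gets a dead gradient.
    g_false, g_true = merge_grad(upstream, taken, 2)
    # Each branch backpropagates only if its gradient is live.
    dx_false = DEAD if g_false is DEAD else g_false * false_dfn(x_false)
    dx_true = DEAD if g_true is DEAD else g_true * true_dfn(x_true)
    # Backward through switch: collect the single live gradient onto x.
    dx, _ = merge([dx_false, dx_true])
    return out, dx


# Branches: true -> 3 * x, false -> x * x.
print(cond_forward_backward(True, 2.0, lambda v: 3.0 * v, lambda v: 3.0,
                            lambda v: v * v, lambda v: 2.0 * v))
print(cond_forward_backward(False, 2.0, lambda v: 3.0 * v, lambda v: 3.0,
                            lambda v: v * v, lambda v: 2.0 * v))
```

With pred True the result is (6.0, 3.0) and with pred False it is (4.0, 4.0): in each case only the branch that actually ran contributes to the gradient of x, and the other branch never sees a live gradient.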