tf.GradientTape returns None for gradient

I am using tf.GradientTape().gradient() to compute a representer point, which can be used to compute the "influence" of a given training example on a given test example. A representer point for a given test example x_t and training example x_i is computed as the dot product of their feature representations, f_t and f_i, multiplied by a weight alpha_i.

Note: The details of this approach are not necessary for understanding the question, since the main issue is getting gradient tape to work. That being said, I have included a screenshot of some of the details below for anyone who is interested.

Computing alpha_i requires differentiation, since it is expressed as the following:

[Equation image: alpha_i expressed via the derivative of the loss L with respect to the pre-softmax output phi.]

In the equation above, L is the standard loss function (categorical cross-entropy for multiclass classification) and phi is the pre-softmax activation output (so its length is the number of classes). Furthermore, alpha_i can be broken up into alpha_ij, which is computed with respect to a specific class j. Therefore, we just take the pre-softmax output phi_j corresponding to the predicted class of the test example (the class with the highest final prediction).
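Written out (this is my reconstruction from the description above; the constant scaling factor that appears in the screenshot is omitted), the quantities are roughly:

    \alpha_{ij} \propto \frac{\partial L(x_i, y_i)}{\partial \phi_j(x_i)}

    k_{ij} = \alpha_{ij} \, (f_i \cdot f_t)

i.e. alpha_ij is the gradient of the training loss with respect to the j-th pre-softmax output, and k_ij weights the feature dot product by it.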

I have created a simple setup with MNIST and have implemented the following:

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, layers

num_classes = 10  # MNIST has 10 classes

def simple_mnist_cnn(input_shape=(28, 28, 1)):
  input = Input(shape=input_shape)
  x = layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(input)
  x = layers.MaxPooling2D(pool_size=(2, 2))(x)
  x = layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x)
  x = layers.MaxPooling2D(pool_size=(2, 2))(x)
  x = layers.Flatten()(x) # feature representation
  output = layers.Dense(num_classes, activation=None)(x) # pre-softmax activation output (logits)
  activation = layers.Activation(activation='softmax')(output) # final output with softmax activation
  model = tf.keras.Model(input, [x, output, activation], name="mnist_model")
  return model
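
(For completeness, a minimal training setup could look like the following. This sketch is my own addition, not from the original post; the preprocessing, optimizer, and epoch count are placeholders. Only the softmax head needs a loss, so training goes through a single-output wrapper model that shares the same layers.)

# train only the softmax output; the feature and pre-softmax outputs need no loss
model = simple_mnist_cnn()
train_model = tf.keras.Model(model.input, model.outputs[-1])  # softmax head only
train_model.compile(optimizer="adam",
                    loss="categorical_crossentropy",
                    metrics=["accuracy"])

# standard MNIST preprocessing: scale to [0, 1], add channel axis, one-hot labels
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32")[..., np.newaxis] / 255.0
y_train = tf.keras.utils.to_categorical(y_train, num_classes)

train_model.fit(x_train, y_train, batch_size=128, epochs=1)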

Now assume the model is trained, and I want to compute the influence of a given training example on a given test example's prediction, perhaps for model understanding/debugging purposes.

with tf.GradientTape() as t1:
  f_t, _, pred_t = model(x_t) # get features for misclassified example
  f_i, presoftmax_i, pred_i = model(x_i)

  # compute dot product of feature representations for x_t and x_i
  dotps = tf.reduce_sum(tf.multiply(f_t, f_i))

  # get presoftmax output corresponding to highest predicted class of x_t
  phi_ij = presoftmax_i[:, np.argmax(pred_t)]

  # y_i is the actual (one-hot) label for x_i; y_true comes first in the loss
  cl_loss_i = tf.keras.losses.categorical_crossentropy(y_i, pred_i)

alpha_ij = t1.gradient(cl_loss_i, phi_ij)
# note: alpha_ij returns None currently
k_ij = tf.reduce_sum(tf.multiply(alpha_ij, dotps))

The code above gives the following error, since alpha_ij is None: ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor. However, if I change t1.gradient(cl_loss_i, phi_ij) -> t1.gradient(cl_loss_i, presoftmax_i), it no longer returns None. I'm not sure why this is the case. Is there an issue with computing gradients on sliced tensors? Is there an issue with "watching" too many variables? I haven't worked much with gradient tape, so I'm not sure what the fix is, but I would appreciate help.

For anyone who is interested, here are more details: [screenshot with further details of the approach]

Danie answered 1/8, 2021 at 22:23 Comment(2)
Why are you using numpy inside a tensorflow gradient? That is almost surely the issue. – Cowling
I don't know if it is an issue. For example, the tensorflow documentation uses numpy operations inside gradient tape in some cases: tensorflow.org/guide/autodiff. But to be sure, I switched np.argmax(pred_t) to a fixed index (e.g. 0), and the issue still persisted. – Danie

I never see you watch any tensors. Note that the tape only traces tf.Variable by default. Is this missing from your code? Else I don't see how t1.gradient(cl_loss_i, presoftmax_i) is working.

Either way, I think the easiest way to fix it is to do

all_gradients = t1.gradient(cl_loss_i, presoftmax_i)
desired_gradients = all_gradients[:, np.argmax(pred_t)]

so simply do the indexing after the gradient. Note that this can be wasteful (if there are many classes) as you are computing more gradients than you need.
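
Put back into your code, the whole computation could then look like this (a sketch reusing your variable names and the k_ij line from your question):

with tf.GradientTape() as t1:
  f_t, _, pred_t = model(x_t)
  f_i, presoftmax_i, pred_i = model(x_i)

  # dot product of the feature representations of x_t and x_i
  dotps = tf.reduce_sum(tf.multiply(f_t, f_i))

  # the loss is computed from pred_i, which itself depends on presoftmax_i
  cl_loss_i = tf.keras.losses.categorical_crossentropy(y_i, pred_i)

# gradient w.r.t. the full pre-softmax vector, indexing only afterwards
all_gradients = t1.gradient(cl_loss_i, presoftmax_i)  # shape (1, num_classes)
alpha_ij = all_gradients[:, np.argmax(pred_t)]        # keep only the predicted class j
k_ij = tf.reduce_sum(tf.multiply(alpha_ij, dotps))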

The explanation for why (I believe) your version doesn't work would be easiest to show in a drawing, but let me try to explain: Imagine the computations in a directed graph. We have

presoftmax_i -> pred_i -> cl_loss_i

Backpropagating the loss to the presoftmax is easy. But then you set up another branch,

presoftmax_i -> presoftmax_ij

Now, when you try to compute the gradient of the loss with respect to presoftmax_ij, there is actually no backpropagation path (we can only follow arrows backwards). Another way to think about it: You compute presoftmax_ij after computing the loss. How could the loss depend on it then?
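
Here is a tiny standalone illustration of this (my own toy example, not your model): a tensor produced by slicing inside the tape is a new node that the loss never flows through, so its gradient comes back as None.

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)                 # constants are not watched automatically
    loss = tf.reduce_sum(x ** 2)  # loss is computed from x
    x1 = x[1]                     # new tensor branching off x; loss never uses it

print(tape.gradient(loss, x))   # tf.Tensor([2. 4. 6.], ...) -- a backprop path exists
print(tape.gradient(loss, x1))  # None -- no path from the loss back to x1
del tape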

Downspout answered 2/8, 2021 at 8:45 Comment(3)
Thanks for your response! That was my concern as well. Also, I didn't forget to include a watch, so I'm not sure how that is working either. Do you know how I'd watch a variable that isn't defined until inside the gradient tape (such as presoftmax_ij)? One more thing to note (which might be obvious) is that any indexing seems to be an issue. For example, I changed presoftmax_i to presoftmax_i[:,:], which is equivalent, and the latter returns None while the former does not. – Danie
Also, would you expect anything to change if I computed presoftmax_ij before computing cl_loss_i? I tried changing that, and it doesn't seem to help either. – Danie
1. You can watch variables at any point during the tape, so it shouldn't be a problem to call it on a tensor after it is defined (even inside the tape), but I'm not 100% sure. 2. Regarding indexing with [:,:], that creates a copy of the tensor, so I guess it will be treated as a "new" result and lead to the same issues as indexing with some i. 3. Regarding computing _ij before the loss -- hard to say, because it seems like you edited the question and _ij is gone. ;) But I don't think it will help if the loss is not computed based on _ij. – Downspout
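
To illustrate point 1 with a quick standalone check (my own example, not from the thread): tape.watch can be called on a tensor created inside the tape, and the gradient reaches it as long as the loss is computed from it afterwards.

import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    h = x * 2.0    # intermediate tensor created inside the tape
    tape.watch(h)  # watching inside the tape is allowed (redundant here, since h already derives from a watched tensor)
    loss = tf.reduce_sum(h ** 2)

print(tape.gradient(loss, h))  # tf.Tensor([[ 4.  8. 12.]], ...) -- the loss depends on h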
