Variable tf.Variable has 'None' for gradient in TensorFlow Probability

I'm having trouble constructing a basic Bayesian neural network (BNN) in TensorFlow Probability (TFP). I'm new to TFP and BNNs in general, so I apologize if I've missed something simple.

I can train a basic NN in TensorFlow by doing the following:

model = keras.Sequential([
    keras.layers.Dense(units=100, activation='relu'),
    keras.layers.Dense(units=50, activation='relu'),
    keras.layers.Dense(units=5, activation='softmax')
])

model.compile(optimizer=optimizer, 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    training_data.repeat(), 
    epochs=100, 
    steps_per_epoch=(X_train.shape[0]//1024),
    validation_data=test_data.repeat(), 
    validation_steps=2
)

However, I run into trouble when I try to implement a similar architecture with TFP's DenseFlipout layers:

model = keras.Sequential([
    tfp.layers.DenseFlipout(units=100, activation='relu'),
    tfp.layers.DenseFlipout(units=10, activation='relu'),
    tfp.layers.DenseFlipout(units=5, activation='softmax')
])

model.compile(optimizer=optimizer, 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    training_data.repeat(), 
    epochs=100, 
    steps_per_epoch=(X_train.shape[0]//1024),
    validation_data=test_data.repeat(), 
    validation_steps=2
)

I get the following ValueError:

ValueError: 
Variable <tf.Variable 'sequential_11/dense_flipout_15/kernel_posterior_loc:0' 
shape=(175, 100) dtype=float32> has `None` for gradient. 
Please make sure that all of your ops have a gradient defined (i.e. are differentiable). 
Common ops without gradient: K.argmax, K.round, K.eval.

I've done some googling and looked through the TFP docs, but I'm at a loss, so I thought I would share the issue. Have I missed something obvious?

Thanks in advance.

Panchromatic answered 14/8, 2019 at 12:24 Comment(2)
Is the DenseFlipout layer compatible with data containing exact zeros? (Try using something other than 'relu' just to check.) Hint: 'softmax' is meant to be used with from_logits=False. - Hourglass
Thanks for the hand there on from_logits! I believe DenseFlipout is compatible, but the problem looks likely to be due to the fact that I was using TF2. - Panchromatic
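
As a side note on the from_logits hint in the comments: with from_logits=True the loss applies a softmax to the model output itself, so a final layer that already uses activation='softmax' ends up being passed through softmax twice. The following is a minimal pure-Python sketch of that effect, not the Keras implementation; the helper functions and the example logits are made up for illustration. It does not explain the `None` gradient error, only why the two settings should not be combined.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, true_index):
    # Categorical cross-entropy for a single one-hot label.
    return -math.log(probs[true_index])

# Hypothetical raw outputs (logits) of a 5-unit final layer; class 0 is correct.
logits = [4.0, 1.0, 0.0, -1.0, -2.0]
true_index = 0

# Intended behaviour: softmax is applied exactly once.
probs = softmax(logits)
loss_correct = cross_entropy(probs, true_index)

# activation='softmax' combined with from_logits=True: softmax is applied twice.
double = softmax(probs)
loss_double = cross_entropy(double, true_index)

print("loss with one softmax: ", loss_correct)
print("loss with two softmaxes:", loss_double)
```

Because the inputs to the second softmax all lie in [0, 1], the largest probability it can produce for 5 classes is e / (e + 4), so the loss can never get close to zero no matter how confident the network is. Either drop the softmax activation and keep from_logits=True, or keep the activation and set from_logits=False.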

I expect it's because you're using TensorFlow 2. Are you? TFP doesn't fully support it yet. If so, downgrading to TensorFlow 1.14 should get it working.

Kimon answered 15/8, 2019 at 14:1
