Tensorflow, negative KL Divergence

I am working with a variational autoencoder (VAE)-type model, and part of my loss function is the KL divergence between a normal distribution with mean 0 and variance 1 and another normal distribution whose mean and variance are predicted by my model.

I defined the loss in the following way:

import tensorflow as tf  # TF 1.x, where tf.contrib.distributions is available

def kl_loss(mean, log_sigma):
    # Standard normal prior with the same shape as the predicted parameters
    normal = tf.contrib.distributions.MultivariateNormalDiag(tf.zeros(mean.get_shape()),
                                                             tf.ones(log_sigma.get_shape()))
    # Encoder distribution; exp(log_sigma) keeps the scale strictly positive
    enc_normal = tf.contrib.distributions.MultivariateNormalDiag(mean,
                                                                 tf.exp(log_sigma),
                                                                 validate_args=True,
                                                                 allow_nan_stats=False,
                                                                 name="encoder_normal")
    kl_div = tf.contrib.distributions.kl_divergence(normal,
                                                    enc_normal,
                                                    allow_nan_stats=False,
                                                    name="kl_divergence")
    return kl_div

The inputs are unconstrained vectors of length N with

log_sigma.get_shape() == mean.get_shape()
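
For illustration, this is roughly how the function is called; the concrete shapes and placeholder names below are just an example, not my actual model:

    # Illustrative shapes only: a batch of 32 latent vectors of length N = 64
    mean = tf.placeholder(tf.float32, shape=[32, 64])
    log_sigma = tf.placeholder(tf.float32, shape=[32, 64])
    kl = kl_loss(mean, log_sigma)  # one KL value per batch element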

Now during training I observe a negative KL divergence after a few thousand iterations, with values going down to about -10. Below you can see the TensorBoard training curves:

KL divergence curve

Zoom in of KL divergence curve

Now this seems odd to me, as the KL divergence should be non-negative whenever it is defined. I understand that "The K-L divergence is only defined if P and Q both sum to 1 and if Q(i) > 0 for any i such that P(i) > 0" (see https://mathoverflow.net/questions/43849/how-to-ensure-the-non-negativity-of-kullback-leibler-divergence-kld-metric-rela), but I don't see how this could be violated in my case. Any help is highly appreciated!
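
For reference, the divergence the snippet computes, KL(N(0, I) || N(mean, diag(exp(log_sigma))^2)), has a simple closed form that is non-negative by construction, so any negative values must come from numerics. A minimal sketch of computing it directly (assuming, as in my code above, that log_sigma is the log standard deviation):

    def kl_loss_closed_form(mean, log_sigma):
        # Closed-form KL( N(0, I) || N(mean, diag(sigma^2)) ) with sigma = exp(log_sigma),
        # i.e. the same argument order as the kl_divergence call above.
        inv_var = tf.exp(-2.0 * log_sigma)  # 1 / sigma^2
        per_dim = (1.0 + tf.square(mean)) * inv_var - 1.0 + 2.0 * log_sigma
        return 0.5 * tf.reduce_sum(per_dim, axis=-1)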

Hern answered 2/3, 2018 at 11:6 Comment(8)
What is your final layer's activation function? – Rosary
The last layer is a 3D conv layer with a linear (None) activation function (tensorflow.org/api_docs/python/tf/layers/conv3d) and kernel size 1. I flatten the resulting tensor, and the first half ends up being my mean, the second half log_sigma. – Hern
So the output from the final layer can be larger than 1 then? – Rosary
Yes, but why does that matter? – Hern
KL divergence is calculated on probability distributions. I am wondering how you are converting a numeric output into a probability distribution. – Rosary
Well, yes, I am aware. It's pretty clear from the code snippet and my post, though, that I am initializing two multivariate normal distributions with diagonal covariance matrices. The mean and variance of one of these distributions is determined by my network's output. The only constraint on these is that sigma is positive, which is ensured by taking exp(log_sigma). So what is it you are asking? – Hern
I don't think I will be able to help with this. You are getting a negative KL divergence, which can happen if the formula is wrong or the input is wrong. My guess is your input is probably wrong, and your input we can't get without looking at the code. So not much help I can provide, sorry. – Rosary
It has nothing to do with the input, but you clearly do not read what I am writing. The function above should return a positive KL divergence for ANY inputs. I suspect it's a numerical issue and has to do with the way this function is implemented in TensorFlow. I would simply like to hear another opinion about this from somebody who stumbled upon the same issue. – Hern

I faced the same problem. It happens because of the float precision used. If you look closely, the negative values occur close to 0 and are bounded by a small negative value. Adding a small positive value to the loss is a workaround.
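
A minimal sketch of that kind of guard, building on the kl_loss function from the question; the epsilon value and the clamping variant are my own assumptions, not anything built into TensorFlow's kl_divergence:

    # Guard against tiny negative KL values caused by float precision.
    kl_raw = kl_loss(mean, log_sigma)
    kl_safe = kl_raw + 1e-6                # shift by a small positive constant, or
    kl_clamped = tf.maximum(kl_raw, 0.0)   # clamp negatives to exactly zero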

Plural answered 13/11, 2020 at 11:49 Comment(1)
Hi! Please be more elaborate in your answer and provide a solution, or you can move this answer to the comments section. – Wharve
