Generalized dice loss for multi-class segmentation: Keras implementation

I just implemented the generalised dice loss (the multi-class version of the dice loss) in Keras, as described in ref:

(My targets have shape (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes).)

import numpy as np
from keras import backend as K

def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef
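For reference, this is the quantity the function above computes, transcribed from the code rather than from the cited paper. Here r_{ln} is the one-hot target and p_{ln} the prediction for class l at voxel n, and the sums over n run over every voxel of every volume in the batch (the code reduces over axes 0 to 3):

```latex
w_l = \frac{1}{\left(\sum_n r_{ln}\right)^2 + 10^{-5}}, \qquad
\mathrm{GDL} = 1 - 2\,\frac{\sum_l w_l \sum_n r_{ln}\, p_{ln}}{\sum_l w_l \sum_n \left(r_{ln} + p_{ln}\right)}
```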

But something must be wrong. I'm working with 3D images that I have to segment into 4 classes (1 background class and 3 object classes; my dataset is imbalanced). The first odd thing: while my training loss and accuracy improve during training (and converge really fast), my validation loss and accuracy stay constant across epochs (see image). Second, when predicting on test data, only the background class is predicted: I get a constant volume.

I used the exact same data and script with categorical cross-entropy loss instead and got plausible results (the object classes are segmented), which means something is wrong with my implementation. Any idea what it could be?

Plus, I believe it would be useful to the Keras community to have a generalised dice loss implementation, as it seems to be used in most recent semantic segmentation tasks (at least in the medical imaging community).

PS: the way the weights are defined seems odd to me; I get values around 10^-10. Has anyone else tried to implement this? I also tested my function without the weights but get the same problems.
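Weights of that order are roughly what the formula produces for 3D volumes: a class covering about 10^5 voxels in a batch gets 1/(10^5)^2 = 10^-10. A quick check with made-up voxel counts (the numbers below are hypothetical, not taken from the question):

```python
import numpy as np

# Hypothetical per-class voxel counts summed over one batch (made-up numbers):
# a large background class followed by three smaller object classes.
counts = np.array([2_000_000.0, 100_000.0, 30_000.0, 5_000.0])

# Same weighting as generalized_dice_loss_w: inverse squared volume plus 0.00001.
w = 1.0 / (counts**2 + 0.00001)

for c, wl in zip(counts, w):
    print(f"{int(c):>9d} voxels -> weight {wl:.1e}")
```

Small absolute values are not a bug by themselves: the weights multiply both the numerator and the denominator, so scaling all of them by the same factor leaves the loss essentially unchanged. What matters is the ratio between classes.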

Cresol answered 27/2, 2018 at 15:14

I think the problem here is your weights. Imagine you are trying to solve a multi-class segmentation problem, but in each image only a few classes are ever present. A toy example of this (and the one which led me to this problem) is a segmentation dataset built from MNIST in the following way.

x is a 28x28 image and y is a 28x28x11 target, where each pixel is classified as background if its normalised grayscale value is below 0.4, and otherwise as the digit that is the original class of x. So if you see a picture of the number one, you will have a bunch of pixels classified as one, and the rest as background.
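A minimal sketch of that target construction in NumPy, assuming the image is already scaled to [0, 1] (e.g. uint8 MNIST pixels divided by 255) and assuming channel 0 is background and channel d + 1 is digit d. The function name `mnist_to_segmentation` is hypothetical, and the usage below uses a synthetic image so no download is needed:

```python
import numpy as np

NUM_CLASSES = 11  # assumption: channel 0 = background, channel d + 1 = digit d

def mnist_to_segmentation(image, digit, threshold=0.4):
    """Turn a 28x28 image with values in [0, 1] and its digit label into a
    28x28x11 one-hot segmentation target."""
    x = np.asarray(image, dtype=np.float32)
    # Per-pixel class index: background below the threshold, else the digit's class.
    labels = np.where(x < threshold, 0, digit + 1)
    # One-hot encode along a new last axis -> shape (28, 28, 11).
    return np.eye(NUM_CLASSES, dtype=np.float32)[labels]

# Synthetic "digit one": an 8x4 bright stroke on a dark background.
image = np.zeros((28, 28), dtype=np.float32)
image[10:18, 12:16] = 1.0
y = mnist_to_segmentation(image, 1)
print(y.shape, np.flatnonzero(y.sum(axis=(0, 1))))
```

Every target built this way has exactly two non-empty channels, which is the situation the weighting has to cope with.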

Now in this dataset only two classes are ever present in an image. This means that, following your dice loss, 9 of the weights will be 1./(0. + eps), which is very large, so for every image we strongly penalise all 9 absent classes. An evidently strong local minimum the network wants to find in this situation is to predict everything as background.
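To see how lopsided this gets, here is the question's inverse-squared-volume weighting (with its 0.00001 constant) applied to one synthetic toy target in which only background and one digit channel are present:

```python
import numpy as np

EPS = 0.00001  # the constant used in generalized_dice_loss_w
NUM_CLASSES = 11

# Synthetic 28x28 label map: background (0) everywhere except an 8x4 patch
# labelled as channel 2, so only two of the eleven classes are present.
labels = np.zeros((28, 28), dtype=np.int64)
labels[10:18, 12:16] = 2
y_true = np.eye(NUM_CLASSES)[labels]      # one-hot, shape (28, 28, 11)

counts = y_true.sum(axis=(0, 1))          # pixels per class in this image
weights = 1.0 / (counts**2 + EPS)         # inverse squared volume

present = counts > 0
print("present classes:", np.flatnonzero(present))
print("weight of an absent class:", weights[~present][0])
print("largest present-class weight:", weights[present].max())
# The nine absent classes are weighted roughly eight orders of magnitude
# more heavily than the most heavily weighted present class.
print("ratio:", weights[~present].min() / weights[present].max())
```

Note that the question's function sums the counts over the whole batch rather than per image, so there a class only gets the huge weight if it is missing from every image in the batch; the per-image computation here shows the effect described above.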

We do want to penalise any incorrectly predicted classes which are not in the image, but not so strongly. So we just need to modify the weights. This is how I did it:

import tensorflow as tf

def gen_dice(y_true, y_pred, eps=1e-6):
    """both tensors are [b, h, w, classes] and y_pred is in logit form"""

    # [b, h, w, classes]
    pred_tensor = tf.nn.softmax(y_pred)
    # cast targets to the prediction dtype so the division below also works for integer labels
    y_true = tf.cast(y_true, pred_tensor.dtype)
    y_true_shape = tf.shape(y_true)

    # [b, h*w, classes]
    y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
    y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])

    # [b, classes]
    # count how many of each class are present in
    # each image; if there are zero, then assign
    # them a fixed weight of eps
    counts = tf.reduce_sum(y_true, axis=1)
    weights = 1. / (counts ** 2)
    weights = tf.where(tf.math.is_finite(weights), weights, tf.ones_like(weights) * eps)

    multed = tf.reduce_sum(y_true * y_pred, axis=1)
    summed = tf.reduce_sum(y_true + y_pred, axis=1)

    # [b]
    numerators = tf.reduce_sum(weights*multed, axis=-1)
    denom = tf.reduce_sum(weights*summed, axis=-1)
    dices = 1. - 2. * numerators / denom
    dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
    return tf.reduce_mean(dices)
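Because the TensorFlow version is awkward to step through, a NumPy mirror of the same computation can help check its values on small arrays. This is a sketch assuming one-hot targets and raw logits, as the docstring of gen_dice states; `gen_dice_np` is a hypothetical helper name, not part of any library:

```python
import numpy as np

def gen_dice_np(y_true, logits, eps=1e-6):
    """NumPy mirror of gen_dice: y_true is one-hot [b, h, w, classes],
    logits has the same shape and is taken before the softmax."""
    y_true = y_true.astype(np.float64)
    # Numerically stable softmax over the class axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    y_pred = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    b, c = y_true.shape[0], y_true.shape[-1]
    y_true = y_true.reshape(b, -1, c)          # [b, h*w, classes]
    y_pred = y_pred.reshape(b, -1, c)

    counts = y_true.sum(axis=1)                # [b, classes]
    with np.errstate(divide="ignore"):
        weights = 1.0 / counts**2              # inf where a class is absent
    weights = np.where(np.isfinite(weights), weights, eps)  # absent classes -> eps

    multed = (y_true * y_pred).sum(axis=1)
    summed = (y_true + y_pred).sum(axis=1)
    numerators = (weights * multed).sum(axis=-1)   # [b]
    denoms = (weights * summed).sum(axis=-1)
    dices = 1.0 - 2.0 * numerators / denoms
    dices = np.where(np.isfinite(dices), dices, 0.0)
    return dices.mean()

rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(2, 8, 8))
y_true = np.eye(4)[labels]                          # [2, 8, 8, 4] one-hot
print(gen_dice_np(y_true, 50.0 * y_true))           # near-perfect logits: close to 0
print(gen_dice_np(y_true, np.zeros_like(y_true)))   # uniform softmax: clearly above 0
```

Feeding the same arrays to gen_dice and comparing the results is a cheap regression test. Since gen_dice applies its own softmax, it expects a model without a final softmax layer; if the model already outputs probabilities, the tf.nn.softmax call would need to go.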
Pneumodynamics answered 17/11, 2019 at 14:44

© 2022 - 2024 — McMap. All rights reserved.