Multi-class weighted loss for semantic image segmentation in keras/tensorflow
Given batched RGB images as input, shape=(batch_size, width, height, 3)

And a multiclass target represented as one-hot, shape=(batch_size, width, height, n_classes)

And a model (U-Net, DeepLab) with a softmax activation in the last layer.

I'm looking for a weighted categorical cross-entropy loss function in Keras/TensorFlow.

The class_weight argument of fit_generator doesn't seem to work, and I didn't find the answer here or in https://github.com/keras-team/keras/issues/2115.

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loss = ?...
        return loss

    return wcce
Hakon answered 29/12/2019 at 15:40
By multiclass target, do you mean that more than one outcome is considered? (Gerlach)
What do you mean by "outcome"? Multiclass means different pixel values indicate different classes, and you can have more than two classes (two classes = binary classification). (Hakon)
I was thinking of multi-label classification, where more than one class can be true per sample; that's what confused me. (Gerlach)
I'll answer my own question:

from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights = [0.9, 0.05, 0.04, 0.01]
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        # per-pixel cross-entropy, scaled by the weight of each pixel's true class
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce

Usage:

loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)
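As a quick sanity check of the weighting math, here is a NumPy re-implementation of the same per-pixel formula (an illustrative sketch, not Keras code; the helper name `weighted_cce_np` is made up for this example). Each pixel's cross-entropy is scaled by the weight of its one-hot true class:

```python
import numpy as np

def weighted_cce_np(y_true, y_pred, weights, eps=1e-7):
    # Per-pixel categorical cross-entropy: -sum(y_true * log(y_pred)) over classes
    cce = -np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)), axis=-1)
    # Scale each pixel's loss by the weight of its one-hot true class
    pixel_w = np.sum(y_true * np.asarray(weights), axis=-1)
    return cce * pixel_w

weights = [0.9, 0.05, 0.04, 0.01]
y_true = np.array([[1., 0., 0., 0.],
                   [0., 0., 1., 0.]])           # two pixels, 4 classes
y_pred = np.array([[0.7, 0.1, 0.1, 0.1],
                   [0.25, 0.25, 0.25, 0.25]])
print(weighted_cce_np(y_true, y_pred, weights))
# pixel 0: -log(0.7) * 0.9 ; pixel 1: -log(0.25) * 0.04
```

A pixel of the dominant class (weight 0.9) contributes far more to the loss than one of the rarest class (weight 0.01), which is exactly the knob the weights give you.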
Hakon answered 29/12/2019 at 20:57
I'm using the Generalized Dice Loss. It works better than weighted categorical cross-entropy in my case. My implementation is in PyTorch; however, it should be fairly easy to translate.

import torch
import torch.nn as nn

class GeneralizedDiceLoss(nn.Module):
    def __init__(self):
        super(GeneralizedDiceLoss, self).__init__()

    def forward(self, inp, targ):
        # (batch, n_classes, H, W) -> (batch, H, W, n_classes)
        inp = inp.contiguous().permute(0, 2, 3, 1)
        targ = targ.contiguous().permute(0, 2, 3, 1)

        # per-class weight: inverse squared class volume
        w = 1. / (torch.sum(targ, (0, 1, 2)) ** 2 + 1e-9)

        numerator = torch.sum(w * torch.sum(targ * inp, (0, 1, 2)))
        denominator = torch.sum(w * torch.sum(targ + inp, (0, 1, 2)))

        dice = 2. * (numerator + 1e-9) / (denominator + 1e-9)

        return 1. - dice
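For checking the formula independently of PyTorch, the same computation can be sketched in NumPy (the helper name `generalized_dice_loss_np` is made up for this example; it takes channels-last tensors, i.e. the shape the PyTorch code permutes to). A perfect prediction should give a loss of roughly zero:

```python
import numpy as np

def generalized_dice_loss_np(inp, targ, eps=1e-9):
    # inp, targ: (batch, H, W, n_classes); targ one-hot, inp softmax probabilities
    w = 1. / (np.sum(targ, axis=(0, 1, 2)) ** 2 + eps)   # inverse squared class volume
    numerator = np.sum(w * np.sum(targ * inp, axis=(0, 1, 2)))
    denominator = np.sum(w * np.sum(targ + inp, axis=(0, 1, 2)))
    dice = 2. * (numerator + eps) / (denominator + eps)
    return 1. - dice

# Perfect prediction: predicted probabilities equal the one-hot target
labels = np.array([[[0, 1, 2, 0],
                    [1, 1, 2, 0],
                    [0, 2, 2, 1],
                    [0, 0, 1, 2]]])            # (1, 4, 4) class indices
targ = np.eye(3)[labels]                       # (1, 4, 4, 3) one-hot
print(generalized_dice_loss_np(targ, targ))    # ~0
```

The 1/volume² weighting is what makes this "generalized": small classes get large weights, so the loss doesn't get swamped by the background class.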
Bisson answered 30/4/2020 at 19:45
This issue might be similar to: Unbalanced data and weighted cross entropy, which has an accepted answer.

Leighleigha answered 29/12/2019 at 17:22
No, it's not; I'm asking about pixel-wise classification. (Hakon)
