Given batched RGB images as input, shape=(batch_size, width, height, 3)
And a multiclass target represented as one-hot, shape=(batch_size, width, height, n_classes)
And a model (Unet, DeepLab) with softmax activation in the last layer.
I'm looking for a weighted categorical cross-entropy loss function in Keras/TensorFlow.
The class_weight argument in fit_generator doesn't seem to work, and I didn't find the answer here or in https://github.com/keras-team/keras/issues/2115.
def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loss = ?...
        return loss
    return wcce
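
For reference, one way the missing line could be filled in is sketched below. This is only a minimal sketch, assuming TensorFlow 2.x, y_pred already passed through softmax, and one weight per class in the same order as the last axis of y_true. The clipping constant 1e-7 and returning a per-pixel loss (leaving the final averaging to Keras) are my assumptions, not something the question specifies.

```python
import tensorflow as tf


def weighted_categorical_crossentropy(weights):
    # weights: one float per class, e.g. [0.9, 0.05, 0.04, 0.01]
    class_weights = tf.constant(weights, dtype=tf.float32)

    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        w = tf.cast(class_weights, y_pred.dtype)
        # Clip probabilities so log() never sees 0 (1e-7 is an assumed epsilon).
        eps = 1e-7
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        # w broadcasts over the last (class) axis; summing over that axis
        # leaves one loss value per pixel: (batch_size, width, height).
        loss = -tf.reduce_sum(w * y_true * tf.math.log(y_pred), axis=-1)
        return loss

    return wcce
```

It would then be passed to the model as, for example, model.compile(optimizer="adam", loss=weighted_categorical_crossentropy([0.9, 0.05, 0.04, 0.01])). Because y_true is one-hot, each pixel's loss is simply the weight of its true class times the negative log of the predicted probability for that class.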