How to Use Class Weights with Focal Loss in PyTorch for an Imbalanced Dataset in Multiclass Classification

I am working on multiclass classification (4 classes) for a language task, using a BERT model for classification. I am following this blog post: Transfer Learning for NLP: Fine-Tuning BERT for Text Classification. My fine-tuned BERT model outputs log-probabilities through nn.LogSoftmax(dim=1).

My data is quite imbalanced, so I used sklearn.utils.class_weight.compute_class_weight to compute the class weights and passed them to the loss.

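import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
weights = torch.tensor(class_weights, dtype=torch.float)
cross_entropy = torch.nn.NLLLoss(weight=weights)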

My results were not very good, so I thought of experimenting with focal loss, and I have the following code for it.

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits  # stored but not used in forward()
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss
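To see what the default reduction does in the code above, here is a minimal sketch with random logits (the shapes are arbitrary assumptions). nn.CrossEntropyLoss() averages over the batch by default, so pt is computed from the batch-mean loss. With reduction='none' you get one -log(p_t) value per sample, and that per-sample value is what the focal term (1 - p_t)^gamma is defined on.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 4)               # 8 samples, 4 classes (raw scores)
targets = torch.randint(0, 4, (8,))      # integer class labels

# default reduction='mean': one scalar for the whole batch
batch_ce = nn.CrossEntropyLoss()(logits, targets)
# reduction='none': one -log(p_t) value per sample
per_sample_ce = nn.CrossEntropyLoss(reduction='none')(logits, targets)

print(batch_ce.shape)       # torch.Size([])
print(per_sample_ce.shape)  # torch.Size([8])

# per-sample focal modulation: each sample gets its own (1 - p_t) ** gamma factor
gamma = 2
pt = torch.exp(-per_sample_ce)
focal_per_sample = (1 - pt) ** gamma * per_sample_ce
focal_mean = focal_per_sample.mean()
```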

I now have three questions, the first being the most important:

  1. Should I use class weights with focal loss?
  2. If I do need weights inside this focal loss, can I use the weight parameter of nn.CrossEntropyLoss()?
  3. If this implementation is incorrect, what would the proper code be, including the weights (if possible)?
Rapallo answered 9/11, 2020 at 11:53 Comment(2)
Wait, if your data is imbalanced, why did you pick 'balanced' here? I'm rather confused: compute_class_weight('balanced', np.unique(train_labels), train_labels) – Riorsson
@MonaJalal 'balanced' means assigning each class a weight according to the number of samples present in that class, doesn't it? As the documentation says: "If 'balanced', class weights will be given by n_samples / (n_classes * np.bincount(y))." – Rapallo
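A small worked example of that formula, using hypothetical labels for a 4-class problem: 'balanced' gives the rarest classes the largest weights, which is why it suits imbalanced data. It assumes scikit-learn 1.0 or later, where classes and y must be passed as keyword arguments.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# hypothetical imbalanced labels: 60 / 25 / 10 / 5 samples per class
train_labels = np.array([0] * 60 + [1] * 25 + [2] * 10 + [3] * 5)

classes = np.unique(train_labels)
sk_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_labels)

# the documented formula by hand (bincount works here because labels are 0..n_classes-1)
manual = len(train_labels) / (len(classes) * np.bincount(train_labels))

print(manual)  # the rarest class (3) gets the largest weight
```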

The OP has probably found an answer by now; I am writing this for others who run into the same question.

There is one problem in the OP's implementation of focal loss:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same alpha value multiplies the loss for every class output probability (i.e. pt), whatever the true class of the sample is. Additionally, the code does not make clear how pt is obtained for each class. A very good implementation of focal loss can be found in What is Focal Loss and when should you use it, but that implementation is only for binary classification, since its self.alpha tensor holds alpha and 1-alpha for the two classes.

In multi-class or multi-label classification, the self.alpha tensor should contain as many elements as there are labels. The values could be the inverse label frequency or the normalized inverse label frequency (just be cautious with labels that have a frequency of 0).
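A minimal sketch of the multi-class case, assuming raw logits and integer targets; the helper names inverse_frequency_alpha and multiclass_focal_loss are hypothetical, not from a library. alpha has one entry per class (normalized inverse frequency, with zero-frequency labels set to 0 instead of dividing by zero), and each sample picks the alpha of its own true class. In practice you would compute alpha once from the full training labels rather than from a batch.

```python
import torch
import torch.nn.functional as F


def inverse_frequency_alpha(targets, num_classes):
    """Normalized inverse label frequency; labels that never occur get weight 0."""
    counts = torch.bincount(targets, minlength=num_classes).float()
    inv = torch.where(counts > 0, 1.0 / counts.clamp(min=1.0), torch.zeros_like(counts))
    return inv / inv.sum()


def multiclass_focal_loss(logits, targets, alpha, gamma=2.0):
    """Focal loss with one alpha per class, selected by each sample's true label."""
    ce = F.cross_entropy(logits, targets, reduction='none')  # -log(p_t), shape (B,)
    pt = torch.exp(-ce)
    alpha_t = alpha.gather(0, targets)                        # shape (B,)
    return (alpha_t * (1 - pt) ** gamma * ce).mean()


# usage: 4 classes, class 3 never appears in these hypothetical labels
torch.manual_seed(0)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 2])
logits = torch.randn(len(targets), 4)
alpha = inverse_frequency_alpha(targets, num_classes=4)
loss = multiclass_focal_loss(logits, targets, alpha)
```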

Masseur answered 18/10, 2021 at 18:54 Comment(0)

You may find answers to your questions as follows:

  1. Focal loss already handles class imbalance, so weights are not required for the focal loss: the alpha and gamma factors in the focal loss equation take care of the imbalance.
  2. There is no need for extra weights, because focal loss handles the imbalance through the alpha and gamma modulating factors.
  3. The implementation you mentioned is correct according to the focal loss formula, but I had trouble getting my model to converge with that version, so I used the following implementation from the mmdetection framework:
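    import torch.nn.functional as F
    from mmdet.models.losses.utils import weight_reduce_loss

    def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25,
                              reduction='mean', avg_factor=None):
        # pred: raw logits (N, C); target: one-hot labels (N, C)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # "pt" here is 1 - p_t, the probability assigned to the wrong outcome
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss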
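The mmdetection snippet is a sigmoid (one binary problem per class) formulation, so single-label multiclass targets have to be one-hot encoded first, and its alpha is a scalar positive/negative balance rather than a per-class weight. Below is a self-contained sketch under those assumptions; replacing weight_reduce_loss with a plain mean is my own simplification, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss_multiclass(pred, labels, gamma=2.0, alpha=0.25):
    """pred: raw logits (N, C); labels: integer class indices (N,)."""
    # one-hot encode so each class becomes an independent binary problem
    target = F.one_hot(labels, num_classes=pred.size(1)).type_as(pred)
    pred_sigmoid = pred.sigmoid()
    # as in the mmdetection code, this "pt" is 1 - p_t
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weight
    return loss.mean()  # plain mean instead of mmdetection's weight_reduce_loss


torch.manual_seed(0)
pred = torch.randn(5, 4)
labels = torch.tensor([0, 3, 1, 1, 2])
loss = sigmoid_focal_loss_multiclass(pred, labels)
```

With gamma = 0 and alpha = 0.5 this reduces to half the mean binary cross-entropy, which is a handy sanity check.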

You can also experiment with another focal loss version available

Ecclesiastic answered 4/1, 2021 at 14:36 Comment(1)
Isn't this supposed to be softmax? It's multiclass. – Galactopoietic

I think the implementation in your question is wrong: alpha is the class weight.

In cross-entropy, the class weight is alpha_t, as shown in the following expression:

$$\mathrm{CE}(p_t) = -\alpha_t \log(p_t)$$

Note that it is alpha_t (one value per class) rather than a single alpha.

In focal loss, the formula is

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

and, as this popular PyTorch implementation shows, alpha acts the same way as a class weight.
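A quick numerical check of that claim, using random logits and hypothetical per-class weights: setting gamma = 0 in FL(p_t) leaves -alpha_t log(p_t), which matches PyTorch's weighted cross-entropy per sample. With gamma > 0, the (1 - p_t)^gamma factor only shrinks each term further.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 3)
targets = torch.tensor([0, 1, 2, 2, 1, 0])
alpha = torch.tensor([0.2, 0.3, 0.5])      # hypothetical per-class weights alpha_c

ce = F.cross_entropy(logits, targets, reduction='none')   # -log(p_t)
pt = torch.exp(-ce)
alpha_t = alpha[targets]                                  # alpha_t chosen by each sample's true class

gamma = 2.0
focal = -alpha_t * (1 - pt) ** gamma * torch.log(pt)      # FL(p_t)

# gamma = 0: FL reduces to the alpha-weighted cross-entropy
focal_gamma0 = alpha_t * ce
weighted_ce = F.cross_entropy(logits, targets, weight=alpha, reduction='none')
```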

References:

  1. https://amaarora.github.io/2020/06/29/FocalLoss.html#alpha-and-gamma
  2. https://github.com/clcarwin/focal_loss_pytorch
Carmichael answered 9/12, 2021 at 16:6 Comment(0)

I was searching for this myself and found most implementations way too cumbersome. One can use PyTorch's CrossEntropyLoss instead (which also supports ignore_index) and add the focal term. Keep in mind that the class weights must be applied after getting pt from CE, so they have to be applied separately rather than passed to CE as weight=alpha: with weights, CE returns alpha_t * (-log p_t) per sample, and exp(-CE) is then no longer p_t.

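import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2, ignore_index=-100, reduction='mean'):
        super().__init__()
        # use standard CE loss without reduction as the basis
        self.CE = nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
        self.alpha = alpha  # optional 1-D tensor of per-class weights, shape (N,)
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        '''
        input (B, N) raw logits
        target (B) class indices
        '''
        minus_logpt = self.CE(input, target)  # 0 for ignored targets
        pt = torch.exp(-minus_logpt)  # don't forget the minus here
        focal_loss = (1 - pt) ** self.gamma * minus_logpt

        valid = target != self.ignore_index
        # apply class weights; gather only valid targets, since ignore_index may be negative
        if self.alpha is not None:
            alpha = self.alpha.to(device=input.device, dtype=focal_loss.dtype)
            alpha_t = torch.zeros_like(focal_loss)
            alpha_t[valid] = alpha.gather(0, target[valid])
            focal_loss = focal_loss * alpha_t

        if self.reduction == 'mean':
            # average over non-ignored samples only
            focal_loss = focal_loss.sum() / valid.sum().clamp(min=1)
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()
        return focal_loss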
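A quick check of why the weights are kept out of CE, using random logits and hypothetical class weights: with weight=alpha, CrossEntropyLoss returns alpha_t * (-log p_t), so exp(-CE) equals p_t ** alpha_t instead of the true-class probability.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 3)
targets = torch.tensor([0, 2, 1, 1, 0])
alpha = torch.tensor([0.5, 2.0, 4.0])   # hypothetical class weights

# probability the model assigns to each sample's true class
true_pt = torch.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)

plain = nn.CrossEntropyLoss(reduction='none')(logits, targets)
weighted = nn.CrossEntropyLoss(weight=alpha, reduction='none')(logits, targets)

pt_plain = torch.exp(-plain)        # recovers p_t
pt_weighted = torch.exp(-weighted)  # p_t ** alpha_t, not p_t
```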
Dugald answered 17/10, 2023 at 7:34 Comment(0)

I tried to implement it using weights computed by sklearn's compute_class_weight. F.nll_loss expects log-probabilities (e.g. the output of nn.LogSoftmax(dim=1)); I think the code could also work on raw logits by replacing F.nll_loss with F.cross_entropy.

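import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedFocalLoss(nn.Module):

    def __init__(self, alpha, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha  # per-class weights as a float tensor, e.g. from compute_class_weight
        self.gamma = gamma

    def forward(self, inputs, targets):
        # inputs: log-probabilities of shape (B, C); targets: class indices (B,)
        targets = targets.type(torch.long)
        BCE_loss = F.nll_loss(inputs, targets, reduction='none')  # -log(p_t) per sample
        # at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha[targets] * (1 - pt) ** self.gamma * BCE_loss
        # normalize by the summed weights, as nn.NLLLoss(weight=...) does with reduction='mean'
        loss_weighted_manual = F_loss.sum() / self.alpha[targets].sum()
        return loss_weighted_manual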
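A sanity check of the normalization, with hypothetical labels and random log-probabilities: the 'balanced' weights are computed with the same formula compute_class_weight documents. With gamma = 0, the weighted-sum-over-weight-sum expression reproduces nn.NLLLoss(weight=alpha) with its default mean reduction, and gamma > 0 lowers the loss.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
labels = np.array([0, 0, 0, 0, 1, 1, 2, 3])   # hypothetical imbalanced labels, 4 classes
# 'balanced' weights: n_samples / (n_classes * bincount)
alpha = torch.tensor(len(labels) / (4 * np.bincount(labels)), dtype=torch.float)

targets = torch.from_numpy(labels).long()
log_probs = F.log_softmax(torch.randn(len(labels), 4), dim=1)  # stands in for nn.LogSoftmax(dim=1) output

nll = F.nll_loss(log_probs, targets, reduction='none')  # -log(p_t) per sample
pt = torch.exp(-nll)


def weighted_focal(gamma):
    # same expression as forward() above: weighted sum divided by the sum of the weights
    return (alpha[targets] * (1 - pt) ** gamma * nll).sum() / alpha[targets].sum()


builtin = nn.NLLLoss(weight=alpha)(log_probs, targets)   # default reduction='mean'
```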
Yabber answered 23/7 at 3:58 Comment(0)