Pytorch: Weight in cross entropy loss

I was trying to understand how the weight argument of CrossEntropyLoss works through a practical example, so I first ran it as standard PyTorch code and then computed the same thing manually. But the losses are not the same.

import math

import torch
from torch import nn

softmax = nn.Softmax(dim=1)  # dim must be given explicitly in recent PyTorch
sc = torch.tensor([0.4, 0.36])
loss = nn.CrossEntropyLoss(weight=sc)
input = torch.tensor([[3.0, 4.0], [6.0, 9.0]])
target = torch.tensor([1, 0])
output = loss(input, target)
print(output)
>> tensor(1.7529)

Now for the manual calculation, first softmax the input:

print(softmax(input))
>>
tensor([[0.2689, 0.7311],
        [0.0474, 0.9526]])

and then take the negative log of the correct class's probability, multiply by the respective weight, and average over the batch:

((-math.log(0.7311) * 0.36) - (math.log(0.0474) * 0.4)) / 2
>> 0.6662

What am I missing here?

Interrelated answered 24/4, 2020 at 17:29 Comment(0)

To compute the weight of each class, use sklearn.utils.class_weight.compute_class_weight(class_weight, *, classes, y) (see the sklearn documentation). It returns an array of per-class weights. For example:

import numpy as np
import torch
from sklearn.utils import class_weight

x = torch.randn(20, 5)          # 20 samples with 5 features each
y = torch.randint(0, 5, (20,))  # 20 labels drawn from 5 classes
class_weights = class_weight.compute_class_weight(
    'balanced', classes=np.unique(y), y=y.numpy())
class_weights = torch.tensor(class_weights, dtype=torch.float)

print(class_weights)  # e.g. tensor([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

Then pass it to the weight argument of nn.CrossEntropyLoss:

criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
loss = criterion(...)
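
For completeness, here is a minimal sketch of the criterion in use, assuming the x, y, and class_weights from the snippet above; the linear model is hypothetical and only there to produce logits of the right shape:

model = nn.Linear(5, 5)        # hypothetical model: 5 features -> 5 class scores

criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
logits = model(x)              # raw scores of shape (20, 5); no softmax needed
loss = criterion(logits, y)    # log-softmax is applied internally
loss.backward()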
Bonnibelle answered 4/6, 2021 at 11:50 Comment(7)
Should np.unique(y) be sorted in ascending order?Pickmeup
Those will be your classes; yes, by default it is in ascending order.Bonnibelle
Shouldn't you pass 1/class_weights? After all, you want to increase the weight for minority classes.Medan
The code returns the error "compute_class_weight() takes 1 positional argument but 3 were given".Ottilie
@IsaacZhao you need to pass the arguments as keywords: compute_class_weight('balanced', classes=np.unique(y), y=y.numpy())Yellows
@AlaaM. sklearn.utils.class_weight.compute_class_weight actually calculates the inverse frequency as you expected: len(y) / np.bincount(y) -- so the minority class gets a higher weight.Ninnetta
@Ninnetta - You are right! (But to be accurate, they further normalize by n_classes, the number of classes.)Medan
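
The normalization described in the last two comments is easy to check; a quick sketch, assuming sklearn's documented formula for 'balanced', n_samples / (n_classes * np.bincount(y)):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 1, 1, 2])  # imbalanced toy labels
manual = len(y) / (len(np.unique(y)) * np.bincount(y))
print(manual)                                                        # [0.6667 1. 2.]
print(compute_class_weight('balanced', classes=np.unique(y), y=y))  # same values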

For any weighted loss with reduction='mean', the loss is normalized by the sum of the weights of the target classes, not by the batch size: loss = sum_n w[y_n] * (-log p_n[y_n]) / sum_n w[y_n]. So in this case:

((-math.log(0.7311) * 0.36) - (math.log(0.0474) * 0.4)) / (0.4 + 0.36)
>> 1.7531671457872036
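
To verify this end to end, here is a small sketch that reproduces the built-in result from the question's tensors; only the tensor values come from the original post, the variable names are mine:

import torch
from torch import nn

weights = torch.tensor([0.4, 0.36])
input = torch.tensor([[3.0, 4.0], [6.0, 9.0]])
target = torch.tensor([1, 0])

# built-in: divides by the sum of the weights of the target classes
print(nn.CrossEntropyLoss(weight=weights)(input, target))  # tensor(1.7529)

# manual: weighted NLL of the correct classes, same normalization
log_probs = torch.log_softmax(input, dim=1)
picked = weights[target] * -log_probs[torch.arange(2), target]
print(picked.sum() / weights[target].sum())                # tensor(1.7529)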
Interrelated answered 24/4, 2020 at 19:13 Comment(0)
