Correct way to use custom weight maps in unet architecture

There is a well-known trick in the U-Net architecture: using custom weight maps to increase accuracy. Here are the details:

[Image: the pixel-wise weight map from the U-Net paper, w(x) = w_c(x) + w_0 · exp(−(d_1(x) + d_2(x))² / (2σ²)), where d_1 and d_2 are the distances to the nearest and second-nearest cell border.]

Now, by asking here and in multiple other places, I have learned about two approaches. I want to know which one is correct, or whether there is another approach that is more correct.

  1. The first is to use the torch.nn.functional method in the training loop:

    loss = torch.nn.functional.cross_entropy(output, target, w), where w is the calculated custom weight.

  2. The second is to use reduction='none' when instantiating the loss function outside the training loop, criterion = torch.nn.CrossEntropyLoss(reduction='none'),

    and then, in the training loop, multiplying the per-pixel loss by the custom weight (a fuller training-step sketch follows at the end of the question):

    gt # Ground truth, format torch.long
    pd # Network output
    W # per-element weighting based on the distance map from UNet
    loss = criterion(pd, gt)
    loss = W*loss # Ensure that weights are scaled appropriately
    loss = torch.sum(loss.flatten(start_dim=1), dim=1) # Sums the loss per image
    loss = torch.mean(loss) # Average across a batch
    

Now, I am a bit confused: which one is right, is there another way, or are both right?
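For reference, here is a minimal sketch of how I would wire approach 2 into a full training step (net, optimizer, and loader are placeholder names; I assume the network output pd has shape (B, C, H, W) and the precomputed weight map W has shape (B, H, W)):

    import torch

    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    for images, gt, W in loader:                      # gt: (B, H, W) long, W: (B, H, W) float
        pd = net(images)                              # pd: (B, C, H, W) raw logits
        loss = criterion(pd, gt)                      # per-pixel loss, shape (B, H, W)
        loss = W * loss                               # apply the custom weight map
        loss = loss.flatten(start_dim=1).sum(dim=1)   # sum the loss per image -> shape (B,)
        loss = loss.mean()                            # reduce to a scalar before backward()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()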

Neoterism asked 14/10, 2019 at 13:28

The weighting portion looks like simply weighted cross-entropy, which is performed like this for the number of classes (2 in the example below):

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)
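
For instance, a minimal self-contained usage sketch of this class-weighted cross-entropy on a 2-class segmentation output (the shapes here are only an assumption for illustration):

import torch
import torch.nn as nn

weights = torch.FloatTensor([.3, .7])        # one weight per class
loss_func = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(4, 2, 10, 10)           # (batch, classes, H, W)
target = torch.randint(0, 2, (4, 10, 10))    # (batch, H, W), class indices
loss = loss_func(logits, target)             # scalar: class-weighted mean over all pixels
print(loss)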

EDIT:

Have you seen this implementation from Patrick Black?

import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits, dim=1)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
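
To unpack the gather step (my reading of the snippet, not something stated by its author): logp.gather(1, target.view(batch_size, 1, H, W)) picks, for every pixel, the log-probability of that pixel's target class, i.e. the per-pixel negative cross-entropy before weighting; the later division by the sum of weights keeps the loss on roughly the same scale regardless of the weight map. A small check, recomputing the log-probabilities from the same logits and target as above:

# Assumes the logits, target, batch_size, H, W defined above.
logp_full = F.log_softmax(logits, dim=1)                                   # (B, C, H, W)
picked = logp_full.gather(1, target.view(batch_size, 1, H, W)).squeeze(1)  # (B, H, W)
reference = -F.nll_loss(logp_full, target, reduction='none')               # per-pixel log p(target)
print(torch.allclose(picked, reference))  # True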
Possing answered 16/12, 2019 at 23:18 Comment(4)
The thing is, the weight is calculated by a certain function here and is not discrete. For more information, here is the paper: arxiv.org/abs/1505.04597 – Neoterism
@Mark Oh, I see now. So it is a pixel-wise loss output, and the borders are pre-computed using some library like OpenCV, and then those pixel positions are saved for each image and multiplied by the loss tensors later on during training, so that the algorithm focuses on reducing loss in those areas. – Possing
Thanks, this legit looks like an answer. I'll try verifying and implementing it, and will accept your answer after that. – Neoterism
Can you explain the intuition behind this line: logp = logp.gather(1, target.view(batch_size, 1, H, W))? – Neoterism

Note that torch.nn.CrossEntropyLoss() is a class that calls torch.nn.functional.cross_entropy under the hood. See https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

You can pass the weights when you define the criterion. Functionally, both methods are the same.
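
A small script (my own check, not from the linked docs) that illustrates the equivalence, using the per-class weight vector of length C that both forms expect:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 2, 10, 10)           # (batch, classes, H, W)
target = torch.randint(0, 2, (4, 10, 10))    # (batch, H, W)
w = torch.tensor([0.3, 0.7])                 # per-class weights, length = number of classes

loss_class = nn.CrossEntropyLoss(weight=w)(logits, target)
loss_functional = F.cross_entropy(logits, target, weight=w)
print(torch.allclose(loss_class, loss_functional))  # True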

Now, I do not understand your idea of computing the loss inside the training loop in method 1 and outside the training loop in method 2. If you compute the loss outside the loop, then how will you backpropagate?

Rattish answered 13/12, 2019 at 16:45 Comment(5)
I was not confused between using torch.nn.CrossEntropyLoss() and torch.nn.functional.cross_entropy(output, target, w); I was confused about how to use custom weight maps in the loss. Please see this paper: arxiv.org/abs/1505.04597 and let me know if you are still not able to figure out what I am asking. – Neoterism
If I understand it correctly, I think method 2 is the right one. The weights (w) inside the loss torch.nn.functional.cross_entropy(output, target, w) are weights for classes, not w(x) in the formula. We can easily test it with a small script. – Rattish
Yep, I am reaching the same conclusion. I'll get back to you if my network runs as expected and will mark the answer as accepted. – Neoterism
Okay, it's not working. I am getting "grad can be implicitly created only for scalar outputs" when I run the loss = loss*w method. – Neoterism
Are you sure you are summing them up or taking the mean? – Rattish
