How to get top k accuracy in semantic segmentation using PyTorch?

How do you compute the top k accuracy in semantic segmentation? In classification, we might compute the topk accuracy as:

correct = output.eq(gt.view(1, -1).expand_as(output))
Dikmen answered 25/12, 2019 at 3:29 Comment(1)
For normal classification you can check this: discuss.pytorch.org/t/imagenet-example-accuracy-calculation/… Does that help? How is it different for segmentation? - Understructure

You are looking for the torch.topk function, which computes the top k values along a dimension.
The second output of torch.topk is the "arg top k": the indices of the k largest values.

Here's how this can be used in the context of semantic segmentation:
Suppose you have the ground-truth label tensor y of shape b-h-w (dtype=torch.int64).
Your model predicts per-pixel class logits of shape b-c-h-w, where c is the number of classes (including "background"). These logits are the "raw" predictions, before the softmax function transforms them into class probabilities. Since softmax is monotonic, it preserves the ranking of the classes at each pixel, so for top-k purposes it does not matter whether the predictions are "raw" logits or probabilities.
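
As a quick sanity check of that last point, the sketch below compares the top-k class indices computed from raw logits with those computed from their softmax probabilities. The tensor sizes, the seed, and k = 2 are arbitrary values chosen for illustration; barring exact ties, the two sets of indices should coincide.

```python
import torch

torch.manual_seed(0)
b, c, h, w, k = 2, 5, 4, 4, 2  # arbitrary sizes, for illustration only

# random "raw" predictions and the corresponding per-pixel class probabilities
logits = torch.randn(b, c, h, w, dtype=torch.float64)
probs = torch.softmax(logits, dim=1)

# top-k class indices along the class dimension, from both representations
_, tk_logits = torch.topk(logits, k, dim=1)
_, tk_probs = torch.topk(probs, k, dim=1)

# True when both rankings pick the same classes in the same order
same_indices = torch.equal(tk_logits, tk_probs)
```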

# compute the top k predicted classes, per pixel (tk has shape b-k-h-w):
_, tk = torch.topk(logits, k, dim=1)
# you now have k predictions per pixel; a pixel is correct if any of them matches the true label y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)
# cast to float before averaging (mean() is not defined for boolean tensors),
# then take the mean of correct_pixels to get the overall top-k accuracy:
top_k_acc = correct_pixels.float().mean()

Note that this method does not account for "ignore" pixels. They can be excluded with a slight modification to the above code:

# keep only the pixels whose label is not the ignore index:
valid = y != ignore_index
top_k_acc = correct_pixels[valid].float().mean()
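
Putting the pieces together, here is a minimal sketch of a helper that wraps the steps above and runs them on a tiny hand-checkable input. The function name top_k_pixel_accuracy, the default ignore_index=255, and the toy tensors are assumptions made for illustration; they are not part of any library API.

```python
import torch

def top_k_pixel_accuracy(logits, y, k=1, ignore_index=255):
    """Fraction of non-ignored pixels whose true label is among the top-k predictions.

    logits: float tensor of shape (b, c, h, w)
    y:      int64 tensor of shape (b, h, w)
    """
    _, tk = torch.topk(logits, k, dim=1)                 # (b, k, h, w)
    correct = torch.eq(y[:, None, ...], tk).any(dim=1)   # (b, h, w), boolean
    valid = y != ignore_index                            # (b, h, w), boolean
    # cast to float: mean() is not defined for boolean tensors
    return correct[valid].float().mean()

# Toy input: 1 image, 3 classes, 1 x 4 pixels; each inner row holds one class's scores.
logits = torch.tensor([[[[3., 1., 2., 3.]],    # class 0
                        [[2., 3., 1., 1.]],    # class 1
                        [[1., 2., 3., 2.]]]])  # class 2
y = torch.tensor([[[1, 0, 2, 255]]])           # last pixel is ignored

top1 = top_k_pixel_accuracy(logits, y, k=1)    # 1 of 3 valid pixels is correct
top2 = top_k_pixel_accuracy(logits, y, k=2)    # 2 of 3 valid pixels are correct
```

If every pixel is ignored, the mean is taken over an empty tensor and the result is NaN, so a caller may want to handle that case separately.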
Leadin answered 25/12, 2019 at 9:52 Comment(0)

Assuming your output is a series of scores ordered according to your list of class labels:

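import torch

# indices of the k highest scores (scores are returned in descending order)
scores, indices = torch.topk(output, k)
# look up the labels of the top-k classes (labels must support tensor indexing, e.g. a torch.Tensor)
correct = labels[indices]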
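
For illustration, here is a small sketch of that lookup for a single score vector. The class names and scores are made up, and labels is assumed to be a plain Python list, so the indices are converted with .tolist() before indexing.

```python
import torch

labels = ["background", "road", "car", "person"]  # hypothetical class names
output = torch.tensor([0.1, 2.5, 0.3, 1.7])       # made-up scores, one per class

k = 2
scores, indices = torch.topk(output, k)           # highest scores first

# map each top-k index back to its class name
top_k_labels = [labels[i] for i in indices.tolist()]
```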
Kerk answered 12/3, 2021 at 9:10 Comment(0)
