I saw a Sudoku-solver CNN that uses sparse categorical cross-entropy as its loss function in the TensorFlow framework, and I am wondering if there is a similar function in PyTorch. If not, how could I calculate the loss over a 2D array in PyTorch?
Is there a version of sparse categorical cross-entropy in PyTorch?
Here is an example of using nn.CrossEntropyLoss for image segmentation with a batch of size 1, width 2, height 2, and 3 classes.
Image segmentation is a classification problem at the pixel level. Of course, you can also use nn.CrossEntropyLoss for basic image classification.
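As a minimal sketch of that plain-classification case (the shapes and values here are illustrative, not from the question), the logits have shape (N, C) and the targets are N integers:
import torch
from torch import nn

logits = torch.randn(4, 3)            # N=4 samples, C=3 classes (arbitrary sizes)
targets = torch.tensor([0, 2, 1, 2])  # integer class labels, no one-hot encoding
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss)  # scalar mean loss over the batch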
The Sudoku problem in the question can be seen as an image segmentation problem where you have 10 classes (the 10 digits), though neural networks are not appropriate for combinatorial problems like Sudoku, which already have efficient exact resolution algorithms.
nn.CrossEntropyLoss accepts ground-truth labels directly as integers in [0, N_CLASSES) (no need to one-hot encode the labels):
import torch
from torch import nn
import numpy as np
# logits predicted
x = np.array([[
[[1,0,0],[1,0,0]], # predict class 0 for pixel (0,0) and class 0 for pixel (0,1)
[[0,1,0],[0,0,1]], # predict class 1 for pixel (1,0) and class 2 for pixel (1,1)
]])*5 # scale the logits by 5 to make the predictions more confident, so wrong pixels stand out in the loss
print("logits map:")
print(x)
# ground truth labels
y = np.array([[
[0,1], # must predict class 0 for pixel (0,0) and class 1 for pixel (0,1)
[1,2], # must predict class 1 for pixel (1,0) and class 2 for pixel (1,1)
]])
print("\nlabels map:")
print(y)
x = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2)  # preds must be (N, C, H, W) instead of (N, H, W, C)
y = torch.tensor(y, dtype=torch.long)  # labels must be (N, H, W) with long integer dtype
losses = nn.CrossEntropyLoss(reduction="none")(x, y)  # reduction="none" keeps the loss per pixel
print("\nlosses map:")
print(losses)
# notice that the loss is big only for pixel (0,1) where we predicted 0 instead of 1
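For reference, with these scaled logits the per-pixel losses come out to roughly 0.013 for the three correctly predicted pixels (log(1 + 2*e^-5)) and about 5.013 for pixel (0,1) (log(e^5 + 2)), where class 0 was predicted instead of class 1.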
nn.CrossEntropyLoss is sparse categorical cross-entropy (i.e. it takes integers as targets instead of one-hot vectors). – Menon
For the Sudoku problem, the predictions would be a [B, 9, 81] tensor of logits and the targets would be a tensor of shape [B, 81] containing integers between 0 and 8 inclusive (corresponding to the marks 1 through 9 respectively), where B is the batch size. – Menon
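A minimal sketch of that shape convention (B = 2 is an arbitrary choice here); nn.CrossEntropyLoss handles the extra cell dimension directly, with the class dimension second:
import torch
from torch import nn

B = 2                                   # arbitrary batch size for the sketch
logits = torch.randn(B, 9, 81)          # 9 classes (digits), 81 cells per board
targets = torch.randint(0, 9, (B, 81))  # one integer mark in [0, 9) per cell
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss)  # scalar mean loss over all cells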