How can I pass target_transform a function that converts the labels to one-hot encoding?
For example, the MNIST dataset in torchvision:
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/',
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)
I tried F.one_hot(), but it didn't work.
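A likely reason F.one_hot fails when passed directly is that torchvision's MNIST (in current versions) hands each target to target_transform as a plain Python int, while F.one_hot expects an integer tensor. The following is a minimal sketch under that assumption: wrap the label in torch.tensor first, then call F.one_hot and convert the result to float. The helper name one_hot_target and the constant NUM_CLASSES are hypothetical, chosen for illustration.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10  # MNIST has ten digit classes (0-9)


def one_hot_target(label, num_classes=NUM_CLASSES):
    """Hypothetical helper: turn an integer class label into a float one-hot vector.

    Assumes the dataset passes the label as a plain Python int, so it is
    wrapped in an int64 tensor before calling F.one_hot.
    """
    return F.one_hot(torch.tensor(label), num_classes=num_classes).float()


# Hypothetical usage with the dataset from the question (not executed here,
# because it would download MNIST):
# train_dataset = torchvision.datasets.MNIST(root='./mnist_data/',
#                                            train=True,
#                                            download=True,
#                                            transform=train_transform,
#                                            target_transform=one_hot_target)

print(one_hot_target(3))
```

A named function is used rather than a lambda because lambdas cannot be pickled, which matters if the DataLoader starts its worker processes with the spawn method.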
torch.nn.functional.one_hot(torch.tensor(2), 5).type(torch.cuda.FloatTensor) – Amaral
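The comment's expression casts the result to torch.cuda.FloatTensor, which requires a GPU. Creating CUDA tensors inside target_transform can also cause problems, because the transform runs in DataLoader worker processes when num_workers > 0. The sketch below is an assumption-labeled alternative, not the commenter's code: it uses .float() so it runs on CPU, and it also shows encoding a whole batch of labels after the DataLoader and moving it to the device there. The label batch is made up for illustration.

```python
import torch
import torch.nn.functional as F

# The comment's expression, made device-independent: .float() yields a
# CPU float32 tensor instead of requiring torch.cuda.FloatTensor.
vec = F.one_hot(torch.tensor(2), 5).float()

# Alternative: leave target_transform unset and one-hot encode each batch of
# integer labels in the training loop. This batch is made up for illustration.
labels = torch.tensor([7, 2, 1, 0])  # shape (batch,), dtype int64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = F.one_hot(labels, num_classes=10).float().to(device)  # shape (4, 10)

print(vec)
print(targets.shape)
```

Encoding per batch keeps the dataset returning plain integer labels, which is also what loss functions such as nn.CrossEntropyLoss expect by default.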