How to transform labels in PyTorch to one-hot

How can I give target_transform a function that converts the labels to one-hot encoding?

For example, the MNIST dataset in torchvision:

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)

I tried F.one_hot(), but it didn't work.

Flaviaflavian asked 10/8, 2020 at 14:25
Comment by Amaral: What's the issue with F.one_hot()? It's working fine for me as torch.nn.functional.one_hot(torch.tensor(2), 5).type(torch.cuda.FloatTensor)
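
The comment's call works because it passes a tensor: F.one_hot rejects a plain Python int, which is what an MNIST target is. A minimal sketch, assuming integer labels and 10 classes; to_one_hot is a hypothetical helper name, not part of torchvision:

```python
import torch
import torch.nn.functional as F


def to_one_hot(label, num_classes=10):
    """Convert an integer class label into a float one-hot vector.

    Hypothetical helper (not a torchvision API): F.one_hot only accepts
    an int64 tensor, so the Python int is wrapped in one first.
    """
    index = torch.tensor(label, dtype=torch.long)  # 0-d int64 tensor
    return F.one_hot(index, num_classes).float()   # shape (num_classes,)


# Passing the int directly fails: F.one_hot expects a Tensor argument.
try:
    F.one_hot(2, 5)
except TypeError as err:
    print("TypeError:", err)

print(to_one_hot(2, num_classes=5))  # tensor([0., 0., 1., 0., 0.])
```

It can be passed as target_transform=to_one_hot. A module-level function, unlike a lambda, can be pickled, which matters when a DataLoader starts worker processes with the spawn method.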

This is how I implemented it. Not sure if there's a cleaner way.

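import torch
import torch.nn.functional as F
import torchvision

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
                                 transform=torchvision.transforms.ToTensor(),
                                 target_transform=torchvision.transforms.Compose([
                                 lambda x: torch.tensor(x),  # int label -> 0-d int64 tensor; LongTensor([x]) would add an extra dim
                                 lambda x: F.one_hot(x, 10)]),  # -> shape (10,), dtype int64
                                 download=True)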
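  • The input to F.one_hot must be an index tensor, i.e. int64 (torch.long).

  • torchvision.transforms.ToTensor can't be used here because the label is a plain int, not an image.

  • torch.LongTensor and torch.tensor behave differently with an int argument: torch.LongTensor(3) allocates an uninitialized tensor with 3 elements, whereas torch.tensor(3) creates a 0-d tensor holding the value 3.

  • Pass the number of classes explicitly. Otherwise F.one_hot infers it as the largest value in the input plus one, which for a single label gives a vector of the wrong length.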
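
The points above can be checked directly. A small sketch, assuming a single integer label of 3 and MNIST's 10 classes:

```python
import torch
import torch.nn.functional as F

label = 3  # an MNIST-style integer label

# torch.LongTensor(int) treats the int as a size: 3 uninitialized elements.
print(torch.LongTensor(label).shape)    # torch.Size([3])
# LongTensor([int]) and tensor(int) both hold the value 3, with different shapes.
print(torch.LongTensor([label]).shape)  # torch.Size([1])
print(torch.tensor(label).shape)        # torch.Size([])

# The extra dimension carries through to the one-hot result.
print(F.one_hot(torch.LongTensor([label]), 10).shape)  # torch.Size([1, 10])
print(F.one_hot(torch.tensor(label), 10).shape)        # torch.Size([10])

# Without num_classes, one_hot infers max(input) + 1 = 4 classes here.
print(F.one_hot(torch.tensor(label)).shape)            # torch.Size([4])

# A float input is rejected: one_hot needs an int64 index tensor.
try:
    F.one_hot(torch.tensor(3.0), 10)
except RuntimeError as err:
    print("RuntimeError:", err)
```

With a (1, 10) target per sample, a DataLoader batch comes out as (batch, 1, 10) rather than (batch, 10), which is why the 0-d torch.tensor(x) form is used in the snippet above.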

Glennieglennis answered 3/4, 2021 at 9:39

Use a Lambda transform with a user-defined function to turn the integer label into a one-hot encoded tensor.

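import torch
import torchvision
from torchvision.transforms import Lambda, ToTensor
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', train=True,
    download=True, transform=ToTensor(),  # substitute your own train_transform
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)))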
  • It first creates a float zero tensor of size 10 (the number of classes in MNIST) and then calls scatter_, which writes value=1 at the index given by the label y.
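
The scatter_ expression and F.one_hot produce the same vector; the difference is the dtype. A sketch comparing the two for every MNIST label, assuming 10 classes; scatter_one_hot is a hypothetical name wrapping the Lambda body above:

```python
import torch
import torch.nn.functional as F


def scatter_one_hot(y, num_classes=10):
    # Same expression as the Lambda above: a float zero vector with
    # value=1 written at index y along dimension 0.
    return torch.zeros(num_classes, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)


for y in range(10):
    via_scatter = scatter_one_hot(y)
    via_one_hot = F.one_hot(torch.tensor(y), 10).float()  # one_hot returns int64
    assert torch.equal(via_scatter, via_one_hot)

print(scatter_one_hot(7))        # tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
print(scatter_one_hot(7).dtype)  # torch.float32
```

The scatter_ version yields float32 directly, while F.one_hot yields int64 and needs a .float() call before being used with a loss that expects float targets, such as nn.BCEWithLogitsLoss.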
Lucy answered 15/6, 2021 at 15:59