Output and Broadcast shape mismatch in MNIST, torchvision

I am getting the following error when using the MNIST dataset in torchvision:

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

Here is my code:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))
Emancipator answered 12/3, 2019 at 14:51
The MNIST dataset has only 1 channel. You need to change transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), which is for 3 channels. – Selway
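
A minimal sketch of where the message comes from, assuming torchvision's Normalize behaves the way recent versions do internally: it reshapes mean and std to [C, 1, 1] and then subtracts and divides in place (on a clone of the input). Plain torch with a random tensor standing in for an MNIST image (an assumption, so nothing needs downloading) shows that the out-of-place subtraction broadcasts to [3, 28, 28], while the in-place version cannot write that result back into a [1, 28, 28] tensor.

```python
import torch

# Stand-in for a ToTensor() MNIST image: 1 channel, 28x28, values in [0, 1]
image = torch.rand(1, 28, 28)

# 3-channel statistics reshaped to [3, 1, 1] for broadcasting
# (assumed to mirror what torchvision's Normalize does internally)
mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)

# Out-of-place: broadcasting produces a new [3, 28, 28] tensor without complaint
broadcast_shape = (image - mean).shape
print(broadcast_shape)  # torch.Size([3, 28, 28])

# In-place: the result must fit into the [1, 28, 28] storage, so this raises
error_message = None
try:
    image.clone().sub_(mean).div_(std)
except RuntimeError as err:
    error_message = str(err)
    print(error_message)
```

The printed error should match the one in the question, which points at the mean/std tuples rather than at the dataset or DataLoader.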

The error comes from a color vs. grayscale mismatch: the transform is written for 3-channel (RGB) images, but MNIST is grayscale (1 channel).

I fixed it by changing the transform to:

transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST is grayscale (1 channel), so pass one mean and one std
    transforms.Normalize((0.5,), (0.5,)),
])
Emancipator answered 12/3, 2019 at 14:57
Thanks, it worked. Can you briefly explain your solution? Why did this work? – Rearrange
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) passes 3 values, one mean and one std per channel for 3 channels. MNIST has only 1 channel, hence transforms.Normalize((0.5,), (0.5,)). – Emancipator
Thanks Jibin, that makes sense. – Rearrange
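
The rule from the comments can be written per channel. Normalize applies the formula below to each channel c, so mean and std need exactly one entry per channel: 3 for RGB, 1 for MNIST. With mean = std = 0.5, the [0, 1] pixel range produced by ToTensor() maps to [-1, 1]:

```latex
\text{out}_c = \frac{\text{in}_c - \text{mean}_c}{\text{std}_c},
\qquad
x \in [0, 1] \;\Rightarrow\; \frac{x - 0.5}{0.5} = 2x - 1 \in [-1, 1]
```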
