PyTorch RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

I am using code from here to train a model that predicts printed-style digits from 0 to 9:

idx_to_class = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7:"7", 8: "8", 9:"9"}

def predict(model, test_image_name):
    transform = image_transforms['test']
    test_image = Image.open(test_image_name)
    plt.imshow(test_image)
    test_image_tensor = transform(test_image)
    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        ps = torch.exp(out)
        topk, topclass = ps.topk(1, dim=1)
        # print(topclass.cpu().numpy()[0][0])
        print("Image class:  ", idx_to_class[topclass.cpu().numpy()[0][0]])
predict(model, "path_of_test_image")

But I get an error when I try to call predict:

Traceback (most recent call last):

  File "<ipython-input-12-f8636d3ba083>", line 26, in <module>
    predict(model, "/home/x/文档/Deep_Learning/pytorch/MNIST/test/2/QQ截图20191022093955.png")

  File "<ipython-input-12-f8636d3ba083>", line 9, in predict
    test_image_tensor = transform(test_image)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 166, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)

  File "/home/x/.local/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 217, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

How can I fix it? Thanks.

Levant answered 22/10, 2019 at 3:53 Comment(0)

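I suspect your test_image has an additional alpha channel per pixel, so it has 4 channels instead of 3. The Normalize transform in your pipeline has a 3-element mean and std, which is why the subtraction in the traceback fails at dimension 0.
Try: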

test_image = Image.open(test_image_name).convert('RGB')
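
To illustrate why this helps, here is a minimal sketch using Pillow only. The in-memory PNG and the helper name load_rgb are made up for illustration and are not part of the question's code. A PNG saved with transparency opens in mode RGBA with four bands, and convert('RGB') brings it down to the three bands that a 3-value mean/std Normalize expects.

```python
import io

from PIL import Image


def load_rgb(path_or_file):
    """Open an image and force it to 3-channel RGB (hypothetical helper)."""
    return Image.open(path_or_file).convert("RGB")


# Build a small RGBA PNG in memory as a stand-in for a screenshot with alpha.
buf = io.BytesIO()
Image.new("RGBA", (8, 8), (255, 0, 0, 128)).save(buf, format="PNG")
png_bytes = buf.getvalue()

# Opened as-is, the image keeps its alpha band: 4 channels.
raw = Image.open(io.BytesIO(png_bytes))
print(raw.mode, len(raw.getbands()))

# After convert("RGB"), only 3 channels remain.
rgb = load_rgb(io.BytesIO(png_bytes))
print(rgb.mode, len(rgb.getbands()))
```

Grayscale images (mode L, one channel) should also come out with three channels after the same call, so the conversion is a reasonable default before any 3-channel transform.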
Drum answered 22/10, 2019 at 5:10 Comment(4)
Does this mean we cannot use .png files with transparency? - Iphlgenia
@Iphlgenia How do you want to use the transparency information? - Drum
I mentioned .png because R, G, B, A are four channels. If Image.open(test_image_name).convert('RGB') just extracts the pixels and creates the required image without the background or transparency, then that is fine. - Iphlgenia
@Iphlgenia When using .convert('RGB') you can safely use PNG files that might contain an alpha channel without worrying that it will break your code. - Drum
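
On what happens to the transparency: as far as I know, convert('RGB') simply drops the alpha band and keeps the stored RGB values, so fully transparent pixels come out with whatever colour was stored underneath (often black). If you would rather flatten the image onto a background, one possible approach is to composite it first. A minimal sketch, assuming a white background is what you want; flatten_on_white is a hypothetical helper name, not a Pillow function:

```python
from PIL import Image


def flatten_on_white(img):
    """Composite an image over an opaque white background, then return RGB."""
    rgba = img.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")


# A fully transparent image whose stored colour is black.
transparent = Image.new("RGBA", (4, 4), (0, 0, 0, 0))

# Dropping alpha keeps the stored colour under the transparency.
dropped = transparent.convert("RGB")
print(dropped.getpixel((0, 0)))

# Compositing first shows the background through the transparent pixels.
flattened = flatten_on_white(transparent)
print(flattened.getpixel((0, 0)))
```

Either way the result has three channels, so both variants avoid the size mismatch; they differ only in what the formerly transparent pixels look like to the model.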

The same fix as Shai's solution, for those who are using torchvision.io.read_image instead of PIL:

from torchvision.io import read_image, ImageReadMode
test_image = read_image(test_image_path, mode=ImageReadMode.RGB)
Myotome answered 20/3, 2023 at 10:19 Comment(0)
