RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
This is the error I get when I try to train my network.

The class we use to load images from the Caltech 101 dataset was provided to us by our teachers.

from torchvision.datasets import VisionDataset
from PIL import Image

import os
import os.path
import sys


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class Caltech(VisionDataset):
    def __init__(self, root, split='train', transform=None, target_transform=None):
        super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)

        self.split = split # This defines the split you are going to use
                           # (split files are called 'train.txt' and 'test.txt')

        '''
        - Here you should implement the logic for reading the splits files and accessing elements
        - If the RAM size allows it, it is faster to store all data in memory
        - PyTorch Dataset classes use indexes to read elements
        - You should provide a way for the __getitem__ method to access the image-label pair
          through the index
        - Labels should start from 0, so for Caltech you will have labels 0...100 (excluding the background class)
        '''
        # Open the split file in read-only mode and read all lines;
        # a context manager ensures the file handle is closed
        with open(self.split, "r") as f:
            lines = f.readlines()

        # Filter out the lines which start with 'BACKGROUND_Google' as asked in the homework
        self.elements = [i for i in lines if not i.startswith('BACKGROUND_Google')]

        # Delete BACKGROUND_Google class from dataset labels
        self.classes = sorted(os.listdir(self.root))
        self.classes.remove("BACKGROUND_Google")


    def __getitem__(self, index):
        ''' 
        __getitem__ should access an element through its index
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        '''

        img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))

        target = self.classes.index(self.elements[index].rstrip().split('/')[0])

        image, label = img, target # Provide a way to access image and label via index
                           # Image should be a PIL Image
                           # label can be int

        # Applies preprocessing when accessing the image
        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        '''
        The __len__ method returns the length of the dataset
        It is mandatory, as this is used by several other components
        '''
        # Provides a way to get the length (number of elements) of the dataset
        length = len(self.elements)
        return length

The preprocessing phase is handled by this code:

# Define transforms for training phase
train_transform = transforms.Compose([transforms.Resize(256),      # Resizes the shorter side of the PIL image to 256
                                      transforms.CenterCrop(224),  # Crops a central square patch of the image
                                                                   # 224 because torchvision's AlexNet needs a 224x224 input!
                                                                   # Remember this when applying different transformations, otherwise you get an error
                                      transforms.ToTensor(), # Turn PIL Image to torch.Tensor
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalizes tensor with mean and standard deviation
])
# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

Finally, here is the preparation of the datasets and dataloaders:

# Clone github repository with data
if not os.path.isdir('./Homework2-Caltech101'):
  !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git

# Commands to run if you get a 'no such file or directory' error related to ./Homework2-Caltech101/
# !rm -r ./Homework2-Caltech101/
# !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git

DATA_DIR = 'Homework2-Caltech101/101_ObjectCategories'
SPLIT_TRAIN = 'Homework2-Caltech101/train.txt'
SPLIT_TEST = 'Homework2-Caltech101/test.txt'


# 1 - Data preparation
myTrainDS = Caltech(DATA_DIR, split = SPLIT_TRAIN, transform=train_transform)
myTestDS = Caltech(DATA_DIR, split = SPLIT_TEST, transform=eval_transform)

print('My Train DS: {}'.format(len(myTrainDS)))
print('My Test DS: {}'.format(len(myTestDS)))

# 2 - Dataloader preparation
myTrain_dataloader = DataLoader(myTrainDS, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
myTest_dataloader = DataLoader(myTestDS, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

Now, the two .txt files contain the lists of images that belong to the train and test splits, so we have to read them from there; that part seems to have been done correctly. The problem is that when I reach the training phase (see code below) I get the error in the title. I already tried adding the following line to the transform pipeline:

[...]
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),

after the CenterCrop, but it fails with 'Image' object has no attribute 'repeat', so I'm kinda stuck.
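
As far as I can tell, repeat is a torch.Tensor method rather than a PIL Image method, so presumably the Lambda would have to go after ToTensor. A minimal sketch of what I mean (the channel guard is just my guess):

from torchvision import transforms

# repeat only exists on tensors, so the Lambda runs after ToTensor();
# the guard leaves images that are already 3-channel untouched
train_transform = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])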

The training code line that gives me the error is the following:

# Iterate over the dataset
  for images, labels in myTrain_dataloader:

If needed, full error is:

RuntimeError                              Traceback (most recent call last)

<ipython-input-197-0e4710a9855d> in <module>()
     47 
     48   # Iterate over the dataset
---> 49   for images, labels in myTrain_dataloader:
     50 
     51     # Bring data over the device of choice

2 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    817             else:
    818                 del self._task_info[idx]
--> 819                 return self._process_data(data)
    820 
    821     next = __next__  # Python 2 compatibility

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
    844         self._try_put_index()
    845         if isinstance(data, ExceptionWrapper):
--> 846             data.reraise()
    847         return data
    848 

/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
    383             # (https://bugs.python.org/issue2651), so we work around it.
    384             msg = KeyErrorMessage(msg)
--> 385         raise self.exc_type(msg)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-180-0b00b175e18c>", line 72, in __getitem__
    image = self.transform(image)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 70, in __call__
    img = t(img)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 175, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 217, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

I'm using AlexNet and the code I was provided is the following:

net = alexnet() # Loading AlexNet model

# AlexNet has 1000 output neurons, corresponding to the 1000 ImageNet classes
# We need 101 outputs for Caltech-101
net.classifier[6] = nn.Linear(4096, NUM_CLASSES) # nn.Linear in PyTorch is a fully connected layer
                                                 # The convolutional layer is nn.Conv2d

# We just replaced the last layer of AlexNet with a new fully connected layer with 101 outputs
# It is mandatory to study the torchvision.models.alexnet source code
Flanagan answered 6/12, 2019 at 18:43

The first dimension of the tensor is the color channel, so what your error means is that you are feeding the pipeline a grayscale picture (1 channel) while it expects an RGB image (3 channels). You defined a pil_loader function that returns an image in RGB, but you never use it.
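
To see why Normalize is the transform that fails, here is a minimal reproduction of the broadcast error, assuming a 1-channel tensor reaches the 3-channel Normalize:

import torch

gray = torch.rand(1, 224, 224)         # a grayscale image as a tensor
mean = torch.tensor([0.5, 0.5, 0.5])   # the 3-channel mean passed to Normalize

# Normalize subtracts the mean in place; broadcasting [3, 1, 1] against
# [1, 224, 224] would yield shape [3, 224, 224], which an in-place op cannot produce:
# RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
gray.sub_(mean[:, None, None])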

So you have two options:

  1. Work with the image in grayscale instead of RGB, which is computationally cheaper. Solution: in both the train and test transforms, change transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) to transforms.Normalize((0.5,), (0.5,)) (note the trailing commas: Normalize expects sequences, and (0.5) is just a float).

  2. Make sure your image is in RGB. I don't know how your images are stored, but I guess some of them are saved in grayscale. One thing you could try is using the pil_loader function you defined: in your __getitem__ method, change img = Image.open(os.path.join(self.root, self.elements[index].rstrip())) to img = pil_loader(os.path.join(self.root, self.elements[index].rstrip())), as shown in the sketch after this list.
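
A minimal sketch of option 2 applied to the __getitem__ from your class (only the image-loading line changes):

def __getitem__(self, index):
    # pil_loader opens the file and calls img.convert('RGB'), so the image
    # always comes out with 3 channels regardless of how it is stored on disk
    img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))

    target = self.classes.index(self.elements[index].rstrip().split('/')[0])

    # Applies preprocessing when accessing the image
    if self.transform is not None:
        img = self.transform(img)

    return img, target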

Let me know how it goes!

Heaps answered 7/12, 2019 at 3:34
Solution two worked perfectly! Thank you very much! But what is the difference between pil_loader and open? - Flanagan
pil_loader calls img.convert('RGB'), so the image always comes out with 3 channels, while Image.open keeps the file's native mode, which is grayscale for some of these images. - Heaps
Solution 1 was perfect for me, as I had earlier converted RGB to grayscale. - Teodoor
Good answer. Solution 1 fixed it. - Selfcongratulation
