'DataLoader' object does not support indexing

I have downloaded the ImageNet dataset via this PyTorch API by setting download=True, but I cannot iterate through the DataLoader.

The error says "'DataLoader' object does not support indexing".

trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)

I tried a simple approach and just ran the following:

trainloader[0]

In the root directory, the pattern is

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg

The docs on the official website don't say anything else: https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet

What am I doing wrong?

Greegree answered 1/7, 2019 at 15:27 Comment(2)
You are creating a DataLoader from a DataLoader in your example. Is that a mistake or your real code? – Syman
Yes, that is the real code. – Greegree
G
5

Solution

import torch
from torchvision import datasets, transforms

input_transform = transforms.Compose([
    transforms.Resize((255, 255)),  # resize, then crop, so that
    transforms.CenterCrop(224),     # all images end up the same size
    transforms.ToTensor()           # PIL image -> tensor, so batches can be collated
])

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', download=False, transform=input_transform)
# torch.utils.data.DataLoader object
trainloader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)

# Iterate instead of indexing; stop after the first batch
for batch_idx, data in enumerate(trainloader):
    x, y = data  # x: batch of images, y: batch of labels
    break
Greegree answered 2/7, 2019 at 15:42 Comment(0)
S
9

Well, the answer is pretty simple (besides the error mentioned in the other answer).

DataLoader has no __getitem__ method (see the source code for yourself).

It is meant for iterating over data (or batches of data), not for random access. If you want to access a specific element, index the underlying torch.utils.data.Dataset instead; in your case:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train')
trainset[0]  # an (image, label) pair

Getting a batch

If you want to get a batch, you may iterate over the DataLoader and break after the first one:

for batch in dataloader:
    print(batch) # or anything else you want to do
    break
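
A shorter way to grab just one batch, without a loop, is to build an iterator and ask it for its next element. The sketch below uses a small TensorDataset as a stand-in for ImageNet; the names toy_dataset and toy_loader and the tensor shapes are assumptions for illustration only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for ImageNet (an assumption): 10 "images" of shape 3x4x4 with integer labels.
images = torch.zeros(10, 3, 4, 4)
labels = torch.arange(10)
toy_dataset = TensorDataset(images, labels)

toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)

# iter() builds a fresh iterator over the loader; next() pulls its first batch.
first_x, first_y = next(iter(toy_loader))
print(first_x.shape)  # torch.Size([4, 3, 4, 4])
print(first_y)        # tensor([0, 1, 2, 3])
```

Each call to iter(toy_loader) starts over from the beginning, so calling next(iter(toy_loader)) repeatedly returns the first batch every time when shuffle=False.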

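DataLoader draws its indices from a sampler: sequential by default, random with shuffle=True, or whatever sampler you pass in (see samplers). Since the order is decided by the sampler at iteration time, a __getitem__ would not make sense for this object.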

You may also inherit from the DataLoader and create your own __getitem__ function doing what you want (more complicated though).
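
As a rough sketch of what such a subclass could look like: IndexableLoader below is a hypothetical name, not part of PyTorch, and it only gives meaningful results under the stated assumptions (map-style dataset, shuffle=False, default sampler, drop_last=False), because only then does batch i cover a fixed slice of the dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class IndexableLoader(DataLoader):
    """Hypothetical DataLoader subclass that adds batch indexing.

    Assumes a map-style dataset, shuffle=False, the default sampler and
    drop_last=False, so batch i covers dataset items
    [i * batch_size, (i + 1) * batch_size).
    """

    def __getitem__(self, batch_index):
        if not 0 <= batch_index < len(self):
            raise IndexError(f"batch index {batch_index} out of range")
        start = batch_index * self.batch_size
        stop = min(start + self.batch_size, len(self.dataset))
        samples = [self.dataset[i] for i in range(start, stop)]
        # Reuse the loader's own collate function so the result has the
        # same layout as a batch produced by normal iteration.
        return self.collate_fn(samples)


# Toy data (an assumption): 10 one-dimensional samples with labels 0..9.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
targets = torch.arange(10)
loader = IndexableLoader(TensorDataset(features, targets), batch_size=4, shuffle=False)

x, y = loader[2]  # third batch: the last two samples
print(y)          # tensor([8, 9])
```

This bypasses the sampler and the worker processes entirely, which is why it stops matching iteration order as soon as shuffling or a custom sampler is involved.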

Full example

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

for batch in trainloader:
    print(batch)
    break

The above should print the first batch, whatever is inside, assuming the dataset's transform already converts the images to tensors.

Syman answered 2/7, 2019 at 10:48 Comment(7)
Then how do I get it as a batch? – Greegree
Just to clarify, what do you mean by 'dataloader'? According to Anubhav Singh's code, is it the 'trainset' or the 'trainloader'? Because with trainloader it doesn't work! – Greegree
It works if you create the DataLoader from train_dataset. dataloader refers to an instance of that DataLoader class. – Syman
I am really getting confused here. Can you please explain or just update your code? – Greegree
Well, yes and no. The problem I was facing was that the img was a PIL image, and the debugger doesn't say so. Your code still doesn't run without errors. The fix is to add "standard_transforms.ToTensor()," to the transforms. :) – Greegree
Plus, for some reason, increasing the batch size from 1 to anything else doesn't work. – Greegree
Yeah, I assumed correct transformations, and print is a simple example to get the point across. – Syman
W
1

The dataset passed to torch.utils.data.DataLoader() should be of type torch.utils.data.Dataset, not torch.utils.data.DataLoader, which is what you are passing in the code above.

So your code should be:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', 
                                          split='train', 
                                          download=False)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=1, 
                                          shuffle=False, 
                                          num_workers=1)

For more details, check the official torch documentation here.

Wicketkeeper answered 1/7, 2019 at 20:21 Comment(2)
Yes, I see the problem and I tried your solution. I still get the same error, "'DataLoader' object does not support indexing", when I do "trainloader[0]". – Greegree
While true, it does not solve the issue (leaving aside the fact that it repeats the comment). – Syman
R
0

I ended up with this dirty solution:

def Dataloader_by_Index(data_loader, target=0):
    for index, data in enumerate(data_loader):
        if index == target:
            return data
    return None
fifth_element = Dataloader_by_Index(my_data_loader, target=4)
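
If a helper like this is needed, itertools.islice from the standard library does the same skipping without a hand-written loop. A minimal sketch, where nth_batch is a hypothetical name and a plain generator stands in for the DataLoader:

```python
from itertools import islice


def nth_batch(iterable, n, default=None):
    """Return the n-th item (0-based) produced by iterating, or default if it runs out."""
    return next(islice(iterable, n, None), default)


# Stand-in for a DataLoader (an assumption): any iterable of batches behaves the same way.
fake_batches = ([i, i + 1] for i in range(0, 10, 2))
print(nth_batch(fake_batches, 4))  # [8, 9]
```

Like the loop version, this still has to produce every batch before the requested one, so the cost grows with n.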
Ravelment answered 11/5, 2023 at 2:20 Comment(0)
