Problem with DataLoader object not subscriptable

I am running a Python program using PyTorch. I use my own dataset, not a built-in torch dataset; the data is loaded from a pickle file produced by feature extraction. But the following error appears:

Traceback (most recent call last):
  File "C:\Users\hp\Downloads\efficient_densenet_pytorch-master\demo-emotion.py", line 326, in <module>
    fire.Fire(demo)
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\fire\core.py", line 138, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\fire\core.py", line 468, in _Fire
    target=component.__name__)
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\fire\core.py", line 672, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "C:\Users\hp\Downloads\efficient_densenet_pytorch-master\demo-emotion.py", line 304, in demo
    train(model,train_set1, valid_set=valid_set, test_set=test1, save=save, n_epochs=n_epochs,batch_size=batch_size,seed=seed)
  File "C:\Users\hp\Downloads\efficient_densenet_pytorch-master\demo-emotion.py", line 172, in train
    n_epochs=n_epochs,
  File "C:\Users\hp\Downloads\efficient_densenet_pytorch-master\demo-emotion.py", line 37, in train_epoch
    loader=np.asarray(list(loader))
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\torch\utils\data\dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\hp\Anaconda3\envs\tf-gpu\lib\site-packages\torch\utils\data\dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
TypeError: 'DataLoader' object is not subscriptable

The code is:

train_set1 = Owndata()

train1, test1 = train_set1.get_splits()
# prepare data loaders
train_dl = torch.utils.data.DataLoader(train1, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test1, batch_size=1024, shuffle=False)
test_set1 = Owndata()
'''print('test_set# ',test_set)'''  
if valid_size:
    valid_set = Owndata()
    indices = torch.randperm(len(train_set1))
    train_indices = indices[:len(indices) - valid_size]
    valid_indices = indices[len(indices) - valid_size:]
    train_set1 = torch.utils.data.Subset(train_dl, train_indices)
    valid_set = torch.utils.data.Subset(valid_set, valid_indices)
else:
    valid_set = None
model = DenseNet(
    growth_rate=growth_rate,
    block_config=block_config,
    num_classes=10,
    small_inputs=True,
    efficient=efficient,
)
train(model, train_set1, valid_set=valid_set, test_set=test1, save=save, n_epochs=n_epochs, batch_size=batch_size, seed=seed)

Any help is appreciated! Thanks a lot in advance!!

Geter answered 2/5, 2020 at 16:10

The error is not raised by the lines you posted; it comes from the train function at the very end, which you are not showing, where the DataLoader ends up being indexed.

You are confusing two things:

  • torch.utils.data.Dataset is indexable (dataset[5] works fine, for example). It is a simple object that defines how to get a single sample of data.
  • torch.utils.data.DataLoader is not indexable, only iterable. It usually returns batches of data from the above Dataset and can load them in parallel using num_workers. It is what you are trying to index, while you should use the dataset for that.

Please see the PyTorch documentation on data loading to get a better grasp of how these work.
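
In the code you posted, the likely culprit is train_set1 = torch.utils.data.Subset(train_dl, train_indices): it wraps the DataLoader, not the dataset. A minimal sketch of the fix, assuming Owndata is an indexable Dataset as in your snippet, is to pass the dataset to Subset and build the loaders from the subsets afterwards:

import torch

train_set1 = Owndata()  # your Dataset, which is indexable
indices = torch.randperm(len(train_set1))
train_indices = indices[:len(indices) - valid_size]
valid_indices = indices[len(indices) - valid_size:]

# Subset must wrap the Dataset, not a DataLoader
train_subset = torch.utils.data.Subset(train_set1, train_indices)
valid_subset = torch.utils.data.Subset(train_set1, valid_indices)

# build the DataLoaders from the subsets afterwards
train_dl = torch.utils.data.DataLoader(train_subset, batch_size=32, shuffle=True)
valid_dl = torch.utils.data.DataLoader(valid_subset, batch_size=1024, shuffle=False)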

Garonne answered 2/5, 2020 at 19:7

I hope everyone learning PyTorch, like me, can solve this problem.

Try this one:

img, label = next(iter(dataloader))

It does what imgs, labels = dataloader[0] would do if a DataLoader were subscriptable.

If you want to loop over the dataloader instead of printing a single image and label, try the following:

for data in dataloader:
    imgs, target = data
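
Note that next(iter(dataloader)) builds a fresh iterator on every call, so it always returns the first batch of a new pass (a random one when shuffle=True). To step through batches one at a time, keep a single iterator around; a small sketch:

it = iter(dataloader)
imgs1, labels1 = next(it)   # first batch
imgs2, labels2 = next(it)   # second batch from the same iterator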
Coaptation answered 30/4 at 14:28

Let's say dataloader is defined over the dataset with a batch size of 4:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

For a small dataset, if you want to get the i-th batch of images and labels from the dataloader, one can do:

images, labels = list(dataloader)[i]

For a large dataset, one can do:

dataiter = iter(dataloader)
images, labels = next(x for j,x in enumerate(dataiter) if j==i)
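
An equivalent sketch using itertools.islice from the standard library, which likewise avoids materializing every batch:

import itertools

dataiter = iter(dataloader)
images, labels = next(itertools.islice(dataiter, i, None))  # skip the first i batches

Keep in mind that with shuffle=True the batches are reshuffled each time the loader is iterated, so the i-th batch differs between passes.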
Bogard answered 19/8 at 20:23
