Problem Description
I tried to load image data using a PyTorch custom dataset, however, I received the error message listed below. After its occurrence, I checked the data and found that my image set consists of 2 types of shape (512,512,3) and (1024,1024). My assumption is that the error is related to this.
Note: The code is able to read some of the images but throws the error message for others.
Questions
How should one preprocess such image data for training?
Are there any other reasons for the error message?
Error message
KeyError Traceback (most recent call last)
<ipython-input-163-aa3385de8026> in <module>
----> 1 train_features, train_labels = next(iter(train_dataloader))
2 print(f"Feature batch shape: {train_features.size()}")
3 print(f"Labels batch shape: {train_labels.size()}")
4 img = train_features[0].squeeze()
5 label = train_labels[0]
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils /data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1201 else:
1202 del self._task_info[idx]
1203 return self._process_data(data)
1204
1205 def _try_put_index(self):
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1227 self._try_put_index()
1228 if isinstance(data, ExceptionWrapper):
1229 data.reraise()
1230 return data
1231
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
423 # have message field
424 raise self.exc_type(message=msg)
425 raise self.exc_type(msg)
426
427
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas /core/indexes/base.py", line 2898, in get_loc
return self._engine.get_loc(casted_key)
File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 101, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1032, in pandas._libs.hashtable.Int64HashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1039, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 16481
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-161-f38b78d77dcb>", line 19, in __getitem__
img_path =os.path.join(self.img_dir,self.image_ids[idx])
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 882, in __getitem__
return self._get_value(key)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 990, in _get_value
loc = self.index.get_loc(label)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2900, in get_loc
raise KeyError(key) from err
KeyError: 16481
Code
from torchvision.io import read_image
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, dataset, transforms=None, target_transforms=None):
#self.train_data = pd.read_csv("Data/train_data.csv")
self.image_ids = dataset.image_id
self.image_labels = dataset.label
self.img_dir = 'Data/images'
self.transforms = transforms
self.target_transforms = target_transforms
def __len__(self):
return len(self.image_ids)
def __getitem__(self,idx):
# image path
img_path =os.path.join(self.img_dir,self.image_ids[idx])
# image
image = read_image(img_path)
label = self.image_labels[idx]
# transform image
if self.transforms:
image = self.transforms(image)
# transform target
if self.target_transforms:
label = self.target_transforms(label)
return image, label
train_data
is the pandas object of the csv file which has the image id and label information.
from sklearn.model_selection import train_test_split
X_train, X_test = train_test_split(train_data, test_size=0.1, random_state=42)
train_df = CustomImageDataset(X_train)
train_dataloader = torch.utils.data.DataLoader(
train_df,
batch_size=64,
num_workers=1,
shuffle=True)