I need to implement a multi-label image classification model in PyTorch. However, my data is not balanced, so I used `WeightedRandomSampler` in PyTorch to create a custom dataloader. But when I iterate through the custom dataloader, I get the error: `IndexError: list index out of range`.

I implemented the following code using this link: https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3?u=surajsubramanian
```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_weights_for_balanced_classes(images, nclasses):
    # Count how many samples belong to each class
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight each class inversely to its frequency
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # Give every sample the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

weights = make_weights_for_balanced_classes(train_dataset.imgs, len(full_dataset.classes))
weights = torch.DoubleTensor(weights)
sampler = WeightedRandomSampler(weights, len(weights))
train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, pin_memory=True)
```
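One invariant worth checking here: `WeightedRandomSampler` draws indices from `range(len(weights))`, so the weight list must have exactly one entry per item of the dataset given to the same `DataLoader`. If the weights are built over one dataset (say, all images) but the loader wraps a smaller one (say, a training `Subset`), the sampler can produce indices past the end of the smaller dataset. The post does not show how `train_dataset` is built, so this is an assumption, not a diagnosis. A minimal standard-library sketch of per-position weights for a subset, where `subset_weights`, `targets` (a stand-in for the full dataset's labels) and `subset_indices` (a stand-in for the subset's indices) are hypothetical names:

```python
from collections import Counter


def subset_weights(targets, subset_indices):
    """Return one inverse-frequency weight per position in a subset.

    targets: class label of every item in the full dataset
             (hypothetical stand-in for the full dataset's labels).
    subset_indices: full-dataset indices that make up the subset
             (hypothetical stand-in for the subset's index list).
    """
    # Labels of the subset, in subset order (position 0, 1, 2, ...)
    subset_labels = [targets[i] for i in subset_indices]
    counts = Counter(subset_labels)
    total = len(subset_labels)
    # Position k gets the weight of its own class, so the list has exactly
    # len(subset_indices) entries and every sampled index k is valid for the subset.
    return [total / counts[label] for label in subset_labels]


if __name__ == "__main__":
    targets = [0, 0, 0, 1, 2, 2, 1, 0]  # toy labels for an 8-image dataset
    train_indices = [0, 3, 4, 5, 7]     # a 5-image training subset
    w = subset_weights(targets, train_indices)
    print(len(w), w)  # 5 [2.5, 5.0, 2.5, 2.5, 2.5]
```

The resulting list could then be passed as `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` to a `DataLoader` over that same subset.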
Based on the answer in https://mcmap.net/q/1274869/-using-weightedrandomsampler-in-pytorch, the following is my updated code. But even then, when I create a dataloader with `loader = DataLoader(full_dataset, batch_size=4, sampler=sampler)`, `len(loader)` returns 1.
```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

class_counts = [1691, 743, 2278, 1271]
num_samples = np.sum(class_counts)
labels = [tag for _, tag in full_dataset.imgs]

class_weights = [num_samples / class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(num_samples)]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), num_samples)
```
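For reference, when a `DataLoader` is given `batch_size` and a `sampler` (with the default `drop_last=False`), `len(loader)` is the sampler's length divided by the batch size, rounded up, and `WeightedRandomSampler` reports its length as its `num_samples` argument. The helper below is a hypothetical pure-Python sketch of that arithmetic, applied to the class counts above:

```python
def expected_num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a sampler of length num_samples.

    Same rule as torch.utils.data.BatchSampler: the partial final batch
    is kept unless drop_last is True.
    """
    if drop_last:
        return num_samples // batch_size
    # Integer ceiling division
    return (num_samples + batch_size - 1) // batch_size


class_counts = [1691, 743, 2278, 1271]
num_samples = sum(class_counts)  # 5983 images in total
print(expected_num_batches(num_samples, batch_size=4))                  # 1496
print(expected_num_batches(num_samples, batch_size=4, drop_last=True))  # 1495
```

By the same rule, a `len(loader)` of 1 with `batch_size=4` suggests that the sampler passed to that loader reports between 1 and 4 samples, so printing `len(sampler)` next to `num_samples` is a quick check.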
Thanks a lot in advance!

I included a utility function based on the accepted answer below:
```python
import torch
from torch.utils.data import WeightedRandomSampler

def sampler_(dataset):
    dataset_counts = imageCount(dataset)
    num_samples = sum(dataset_counts)
    labels = [tag for _, tag in dataset]

    class_weights = [num_samples / dataset_counts[i] for i in range(n_classes)]
    weights = [class_weights[labels[i]] for i in range(num_samples)]
    sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
    return sampler
```
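To check that inverse-frequency weights like the ones in `sampler_` actually balance the classes, the draw can be simulated without PyTorch. `random.choices` with `weights` and `k` samples indices with replacement in proportion to their weights, which is the behaviour `WeightedRandomSampler` has with its default `replacement=True`; it is used here only as a stand-in, and `simulate_draw` is a hypothetical helper. Because every class's weights sum to the same total (count × total / count), each class should receive roughly `k / n_classes` draws:

```python
import random
from collections import Counter


def simulate_draw(labels, n_classes, k, seed=0):
    """Draw k indices with replacement, weighted by inverse class frequency,
    and return how many drawn items fall in each class."""
    counts = [0] * n_classes
    for tag in labels:
        counts[tag] += 1
    num_samples = sum(counts)
    class_weights = [num_samples / counts[c] for c in range(n_classes)]
    weights = [class_weights[tag] for tag in labels]

    rng = random.Random(seed)  # fixed seed for a reproducible draw
    drawn = rng.choices(range(len(labels)), weights=weights, k=k)
    per_class = Counter(labels[i] for i in drawn)
    return [per_class[c] for c in range(n_classes)]


# Toy labels built from the class counts quoted in the question
class_counts = [1691, 743, 2278, 1271]
labels = [c for c, n in enumerate(class_counts) for _ in range(n)]
# Each class should get roughly 20000 / 4 draws
print(simulate_draw(labels, n_classes=4, k=20000))
```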
The `imageCount` function finds the number of images of each class in the dataset. Each item in the dataset is a tuple of the image and its class, so we use the second element of the tuple.
```python
def imageCount(dataset):
    image_count = [0] * n_classes
    for img in dataset:
        image_count[img[1]] += 1
    return image_count
```
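Iterating over the dataset as `imageCount` does loads (and transforms) every image just to read its label. If the dataset is a torchvision `ImageFolder`, the integer labels are already available as `dataset.targets` (and as the second element of each pair in `dataset.imgs`), so the counts can come from those alone. A sketch with `collections.Counter`, where `count_per_class` is a hypothetical name and a plain list stands in for `dataset.targets`:

```python
from collections import Counter


def count_per_class(targets, n_classes):
    """Count images per class from integer labels, without loading any image."""
    counts = Counter(targets)
    # Classes with no images get 0 instead of being missing
    return [counts.get(c, 0) for c in range(n_classes)]


if __name__ == "__main__":
    targets = [0, 2, 2, 1, 3, 2, 0]  # stand-in for dataset.targets
    print(count_per_class(targets, n_classes=4))  # [2, 1, 3, 1]
```

A class with zero images would still make the inverse-frequency weight `num_samples / 0` fail, so a zero count is worth checking before the weights are built.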
With `loader = DataLoader(dataset, batch_size=4, sampler=sampler)`, the length of my loader, as returned by `len(loader)`, is 1. – Poky