Using WeightedRandomSampler in PyTorch

I need to implement a multi-label image classification model in PyTorch. However, my data is not balanced, so I used the WeightedRandomSampler in PyTorch to create a custom dataloader. But when I iterate through the custom dataloader, I get the error: IndexError: list index out of range

I implemented the following code based on this link: https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3?u=surajsubramanian

def make_weights_for_balanced_classes(images, nclasses):
    # Count how many samples fall in each class.
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight each class by the inverse of its frequency.
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # Assign each sample the weight of its class.
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight
weights = make_weights_for_balanced_classes(train_dataset.imgs, len(full_dataset.classes))
weights = torch.DoubleTensor(weights)
sampler = WeightedRandomSampler(weights, len(weights))

train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, pin_memory=True)
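For what it's worth, a likely cause of that IndexError is a label outside the range [0, nclasses), e.g. when the second tuple element isn't a single integer class index, which can happen with multi-label data. A quick sanity-check sketch, assuming train_dataset.imgs yields (path, label) tuples as in torchvision's ImageFolder:

# Sanity-check sketch: every label must be a valid class index,
# otherwise count[item[1]] raises IndexError.
nclasses = len(full_dataset.classes)
bad = [(path, label) for path, label in train_dataset.imgs
       if not isinstance(label, int) or not 0 <= label < nclasses]
if bad:
    print(len(bad), "samples have invalid labels, e.g.:", bad[:5])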

Based on the answer in https://mcmap.net/q/1274869/-using-weightedrandomsampler-in-pytorch, the following is my updated code. But even then, when I create a dataloader with loader = DataLoader(full_dataset, batch_size=4, sampler=sampler), len(loader) returns 1.

class_counts = [1691, 743, 2278, 1271]
num_samples = np.sum(class_counts)
labels = [tag for _,tag in full_dataset.imgs] 

class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(int(num_samples))]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
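For reference, len(loader) is derived from the sampler's length, not the dataset's: with drop_last=False (the default) it is ceil(len(sampler) / batch_size). A small sketch of the relationship, assuming the code above:

import math

# The original answer used num_samples = len(class_counts), i.e. only
# 4 draws per epoch, which gives ceil(4 / 4) = 1 batch, hence len(loader) == 1.
loader = DataLoader(full_dataset, batch_size=4, sampler=sampler)
assert len(loader) == math.ceil(len(sampler) / 4)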

Thanks a lot in advance!

I've included a utility function based on the accepted answer below:

def sampler_(dataset):
    # n_classes is assumed to be defined globally (see imageCount below).
    dataset_counts = imageCount(dataset)
    num_samples = sum(dataset_counts)
    labels = [tag for _, tag in dataset]

    class_weights = [num_samples / dataset_counts[i] for i in range(n_classes)]
    weights = [class_weights[labels[i]] for i in range(int(num_samples))]
    sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
    return sampler

The imageCount function finds the number of images in each class in the dataset. Each row in the dataset contains the image and its class, so we take the second element of the tuple into consideration.

def imageCount(dataset):
    image_count = [0]*(n_classes)
    for img in dataset:
        image_count[img[1]] += 1
    return image_count
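
A usage sketch (my assumption of how the pieces fit together; note that iterating the dataset to collect labels loads every image once, so this can be slow on large datasets):

n_classes = len(full_dataset.classes)  # must be set before calling sampler_
sampler = sampler_(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, pin_memory=True)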
Poky answered 23/3, 2020 at 10:45 Comment(0)

That code looks a bit complex... You can try the following:

# Let there be 9 samples in class 0 and 1 sample in class 1
class_counts = [9.0, 1.0]
num_samples = sum(class_counts)
labels = [0, 0, ..., 0, 1]  # corresponding labels of samples

class_weights = [num_samples / class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(int(num_samples))]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
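To actually draw batches, pass the sampler to a DataLoader (a sketch, assuming dataset is your Dataset object; the sampler replaces shuffle=True, so don't pass both):

from torch.utils.data import DataLoader

# With replacement=True (the sampler's default), minority-class samples
# can be drawn several times per epoch, balancing each batch on average.
loader = DataLoader(dataset, batch_size=4, sampler=sampler)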
Torto answered 23/3, 2020 at 12:23 Comment(6)
I tried your code, but when I include the sampler while creating my DataLoader, loader = DataLoader(dataset, batch_size=4, sampler=sampler), the length of my loader as returned by len(loader) is 1. – Poky
My bad, I meant to write len, not sum, in line 3, which messed up a few variables later. The new code should be correct. – Torto
Thanks again, I tried the new code as I have updated in my answer, but I get the following error: ValueError: num_samples should be a positive integer value, but got num_samples=5983 – Poky
Updated, please try again. – Torto
Related: here's an end-to-end example of this dummy data from the PyTorch forums. – Salzman
Note to self: num_samples == sum(class_counts) == len(labels) – Super

Here is an alternative solution:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

# Per-class inverse frequency, then one weight per sample.
counts = np.bincount(y)
labels_weights = 1. / counts
weights = labels_weights[y]
sampler = WeightedRandomSampler(weights, len(weights))

where y is a list of labels, one per sample, with shape (n_samples,), encoded as integers 0, ..., n_classes - 1.

weights won't add up to 1, which is ok according to the official docs.
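As a usage sketch (my assumption: full_dataset is a torchvision ImageFolder, which exposes its label list as .targets, so no images need to be loaded):

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# Build y from the stored labels, then weight each sample by the
# inverse frequency of its class.
y = np.array(full_dataset.targets)
counts = np.bincount(y)
labels_weights = 1. / counts
weights = labels_weights[y]
sampler = WeightedRandomSampler(torch.from_numpy(weights), len(weights))
loader = DataLoader(full_dataset, batch_size=4, sampler=sampler)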

Potboiler answered 4/12, 2020 at 10:46 Comment(0)

The previous answers addressed how to do it for single-label classification. For multi-label classification, you will have to do it differently.

Let's say you have 10000 samples, with 10 classes. You want to use WeightedRandomSampler. The weights that you pass to WeightedRandomSampler are the weights for each of those 10000 samples, not the classes. So you will have to calculate the weights for each of those samples by aggregating the class weights for each sample.

Here's one way to do it. This is for one-hot encoded labels:

# Assuming you already have created your train_dataset object,
# which has all the labels stored.
import numpy as np
from torch.utils.data import BatchSampler, DataLoader, WeightedRandomSampler

def calc_sample_weights(labels, class_weights):
    # Aggregate class weights by `sum`. You may use another aggregation.
    return sum(labels * class_weights)

# Specify class weights. You can use any of the methods in the
# other answers to calculate class_weights.
class_weights = np.array([...])

# Create sample weights, i.e. weights for each of the 10000 samples.
sample_weights = [calc_sample_weights(label, class_weights)
                  for label in train_dataset.labels]

# Create the WeightedRandomSampler.
weighted_sampler = WeightedRandomSampler(sample_weights, len(train_dataset))

# Create a BatchSampler for retrieving batches of samples.
batch_size = 32
batch_sampler = BatchSampler(weighted_sampler, batch_size, drop_last=False)

# Create the train dataloader.
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

In the above code, we calculate the sample weights by element-wise multiplying the class_weights with the class_labels of each sample, and then aggregating them through a sum operation. So if class weights are [1.0, 0.5, 0] and the label for a sample is one-hot encoded as [1, 0, 1], then the total weight for that sample would be 1.0. You can do a similar thing with labels that are not one-hot encoded by indexing the class_weights with the class_label indices for a sample and then aggregating the weights.
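
For illustration, here is a sketch of that index-based variant (my own example, not from the original answer; it assumes each label is a list of active class indices):

# Sketch: same aggregation for labels stored as lists of class indices,
# e.g. label = [0, 2] instead of the one-hot [1, 0, 1].
def calc_sample_weight_from_indices(label_indices, class_weights):
    return sum(class_weights[i] for i in label_indices)

class_weights = np.array([1.0, 0.5, 0.0])
print(calc_sample_weight_from_indices([0, 2], class_weights))  # prints 1.0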

Notice that we also create a BatchSampler. That's because when sampling in batches this way, you don't pass weighted_sampler directly; you wrap it in a BatchSampler instead.

Musser answered 5/1 at 2:59 Comment(1)
Your answer could be improved with additional supporting information. Please edit to add further details, such as citations or documentation, so that others can confirm that your answer is correct. You can find more information on how to write good answers in the help center. – Diminution
