How to create batches using PyTorch DataLoader such that each example in a given batch has the same value for an attribute?

Suppose I have a list, datalist, which contains several examples (of type torch_geometric.data.Data for my use case). Each example has an attribute num_nodes.

For demonstration purposes, such a datalist can be created using the following snippet of code:

import torch
from torch_geometric.data import Data # each example is of this type
import networkx as nx # for creating random data
import numpy as np
# the python list containing the examples
datalist = []
for num_node in [9, 11]:
    for _ in range(1024):
        edge_index = torch.from_numpy(
            np.array(nx.fast_gnp_random_graph(num_node, 0.5).edges())
        ).t().contiguous()
        datalist.append(
            Data(
                x=torch.rand(num_node, 5), 
                edge_index=edge_index, 
                edge_attr=torch.rand(edge_index.size(1))
            )
        )

From the above datalist object, I can create a torch_geometric.loader.DataLoader (which subclasses torch.utils.data.DataLoader) naively, i.e. without any such constraint, using the DataLoader constructor:

from torch_geometric.loader import DataLoader
dataloader = DataLoader(
    datalist, batch_size=128, shuffle=True
)

My question is: how can I use the DataLoader class to ensure that each example in a given batch has the same value for the num_nodes attribute?

PS: I tried to solve it and came up with a hacky solution that combines multiple DataLoader objects, using the combine_iterators function snippet from here, as follows:

import random
from collections import defaultdict

def get_combined_iterator(*iterables):
    # Randomly draw the next item from one of the given iterables until all are exhausted
    nexts = [iter(iterable).__next__ for iterable in iterables]
    while nexts:
        next_fn = random.choice(nexts)
        try:
            yield next_fn()
        except StopIteration:
            nexts.remove(next_fn)

# Group the examples by their num_nodes attribute
datalists = defaultdict(list)
for data in datalist:
    datalists[data.num_nodes].append(data)
# One DataLoader per group, so that every batch is drawn from a single group
dataloaders = (
    DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
)
batches = get_combined_iterator(*dataloaders)

But I think there must be a more elegant/better way of doing this, hence this question.

Hamfurd answered 26/1, 2022 at 14:17 Comment(0)

If your underlying dataset is map-style, you can define a torch.utils.data.Sampler which returns the indices of the examples you want to batch together. An instance of it is then passed as the batch_sampler kwarg to your DataLoader, and you can drop the batch_size kwarg, since the sampler will form the batches for you depending on how you implement it.
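
For illustration, here is a minimal sketch of such a sampler in plain PyTorch (not tied to PyTorch Geometric; GroupedBatchSampler and group_keys are just placeholder names). It groups example indices by an attribute value and yields one group's indices per batch:

import random
from collections import defaultdict
from torch.utils.data import Sampler

class GroupedBatchSampler(Sampler):
    # Yields lists of indices such that every example in a batch shares the same key
    def __init__(self, group_keys, batch_size):
        # group_keys[i] is the attribute value (e.g. num_nodes) of example i
        self.batch_size = batch_size
        groups = defaultdict(list)
        for idx, key in enumerate(group_keys):
            groups[key].append(idx)
        self.groups = list(groups.values())

    def __iter__(self):
        batches = []
        for indices in self.groups:
            indices = indices[:]  # copy so the stored lists are not shuffled in place
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        random.shuffle(batches)  # interleave batches coming from different groups
        yield from batches

    def __len__(self):
        return sum(
            (len(indices) + self.batch_size - 1) // self.batch_size
            for indices in self.groups
        )

An instance of this could then be built from something like [data.num_nodes for data in datalist] and passed as batch_sampler.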

Europe answered 26/1, 2022 at 14:45 Comment(2)
If possible could you please explain how to use torch.utils.data.Sampler with an example?Hamfurd
You subclass the torch.utils.data.Sampler and the __iter__ method returns the indices of the examples you want to form the batch. I don't know PyTorch Geometric so I can't really create an example for you using their APIs.Europe

Following Europe's suggestion, I subclassed torch.utils.data.sampler.Sampler to create a new sampler, BucketSampler, which uses torch.utils.data.sampler.SubsetRandomSampler and torch.utils.data.sampler.BatchSampler to batch together examples that have the same value for a given attribute.

import torch
from torch.utils.data.sampler import Sampler, BatchSampler, SubsetRandomSampler

class BucketSampler(Sampler):
    def __init__(self, dataset, batch_size, start_pos_data, generator=None) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.generator = generator
        # Build (start, end) index ranges, one per bucket of consecutive examples
        # that share the same attribute value
        start_end_indices = []
        for i in range(len(start_pos_data) - 1):
            start_end_indices.append((start_pos_data[i], start_pos_data[i + 1]))
        start_end_indices.append((start_pos_data[-1], len(self.dataset)))
        ranges = [range(start, end) for start, end in start_end_indices]
        # One random sampler per bucket, wrapped in a BatchSampler so that each
        # yielded batch only contains indices from a single bucket
        subset_samplers = [SubsetRandomSampler(range_, generator=generator) for range_ in ranges]
        self.samplers = [
            BatchSampler(subset_sampler, batch_size, drop_last=False) for subset_sampler in subset_samplers
        ]
        self._len = sum(len(sampler) for sampler in self.samplers)

    def __iter__(self):
        iterators = [iter(sampler) for sampler in self.samplers]
        # Pick a bucket at random and yield its next batch until every bucket is exhausted
        while iterators:
            randint = torch.randint(0, len(iterators), size=(1,), generator=self.generator).item()
            try:
                yield next(iterators[randint])
            except StopIteration:
                iterators.pop(randint)

    def __len__(self):
        return self._len

Apart from the usual arguments, this class also takes a start_pos_data argument, which is a list containing the starting index of each group of examples in the datalist (the dataset from the example given in the question) that share the same attribute value, i.e. 0 plus every index at which the attribute value changes; the dataset must therefore be sorted by that attribute. For the above example, we can create such a list with the help of the following snippet of code:

# sort datalist to ensure that data items with the same number of nodes are grouped together
sorted_datalist = sorted(datalist, key = lambda data: data.num_nodes)
# initialize the start_pos_data by 0
start_pos_data = [0]
for i in range(1,len(sorted_datalist)):
    if sorted_datalist[i].num_nodes != sorted_datalist[i-1].num_nodes:
        # append when the number of nodes changes 
        start_pos_data.append(i)

Now, start_pos_data can be passed to the constructor of BucketSampler to initialize the sampler:

bucketSampler = BucketSampler(sorted_datalist, batch_size = 128, start_pos_data = start_pos_data)

After this, bucketSampler can be passed as the batch_sampler kwarg to the DataLoader constructor:

from torch_geometric.loader import DataLoader
dataloader = DataLoader(sorted_datalist, batch_sampler = bucketSampler)

This dataloader (upon iteration) will produce the batches in the desired manner.
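
A quick sanity check (just a sketch; it uses to_data_list() from PyTorch Geometric's Batch to recover the individual graphs) can confirm that all graphs within a batch have the same number of nodes:

for batch in dataloader:
    # every graph in this batch should share a single num_nodes value
    sizes = {data.num_nodes for data in batch.to_data_list()}
    assert len(sizes) == 1, f"mixed graph sizes in a batch: {sizes}"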

Hamfurd answered 26/1, 2022 at 14:17 Comment(0)