Suppose I have a list, datalist, which contains several examples (each of type torch_geometric.data.Data in my use case). Each example has a num_nodes attribute. For demonstration purposes, such a datalist can be created with the following snippet:
import torch
from torch_geometric.data import Data  # each example is of this type
import networkx as nx  # for creating random data
import numpy as np

# the python list containing the examples
datalist = []
for num_node in [9, 11]:
    for _ in range(1024):
        edge_index = torch.from_numpy(
            np.array(nx.fast_gnp_random_graph(num_node, 0.5).edges())
        ).t().contiguous()
        datalist.append(
            Data(
                x=torch.rand(num_node, 5),
                edge_index=edge_index,
                edge_attr=torch.rand(edge_index.size(1))
            )
        )
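As a quick sanity check, the list should contain 2048 graphs split evenly between the two node counts (the expected outputs in the comments assume the loop above runs as written):

print(len(datalist))                    # 2048
print({d.num_nodes for d in datalist})  # {9, 11}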
From the above datalist object, I can create a torch_geometric.loader.DataLoader (which subclasses torch.utils.data.DataLoader) naively, i.e. without any constraints, using the DataLoader constructor:
from torch_geometric.loader import DataLoader

dataloader = DataLoader(
    datalist, batch_size=128, shuffle=True
)
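Iterating this naive loader shows that graphs with different node counts get mixed within a batch; here I use batch.batch, the node-to-graph assignment vector that PyG's collation adds, to count the nodes of each graph (just a check I ran; with shuffle=True the exact output varies):

batch = next(iter(dataloader))
# number of nodes of each graph in this batch; typically a mix of 9s and 11s
print(torch.bincount(batch.batch))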
My question is: how can I use the DataLoader class to ensure that every example in a given batch has the same value for the num_nodes attribute?
PS: I tried to solve it myself and came up with a hacky solution that combines multiple DataLoader objects using the combine_iterators function snippet from here:
import random
from collections import defaultdict

def get_combined_iterator(*iterables):
    # draw the next batch from a randomly chosen loader until all are exhausted
    nexts = [iter(iterable).__next__ for iterable in iterables]
    while nexts:
        next_fn = random.choice(nexts)
        try:
            yield next_fn()
        except StopIteration:
            nexts.remove(next_fn)

# group the examples by their node count
datalists = defaultdict(list)
for data in datalist:
    datalists[data.num_nodes].append(data)

# one DataLoader per node count, then interleave their batches
dataloaders = (
    DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
)
batches = get_combined_iterator(*dataloaders)
But I think there must be a more elegant way of doing this, hence this question.
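For concreteness, this is roughly the shape of solution I am hoping for: a custom torch.utils.data.Sampler subclass passed via the batch_sampler argument, which only yields index lists drawn from a single num_nodes group. The class below (GroupedBatchSampler is my own name) is an untested sketch, and I am not sure whether routing it through PyG's DataLoader like this is the intended approach:

import random
from collections import defaultdict
from torch.utils.data import Sampler

class GroupedBatchSampler(Sampler):
    """Yields batches of indices whose examples all share the same num_nodes."""
    def __init__(self, datalist, batch_size):
        groups = defaultdict(list)
        for idx, data in enumerate(datalist):
            groups[data.num_nodes].append(idx)
        self.groups = list(groups.values())
        self.batch_size = batch_size

    def __iter__(self):
        batches = []
        for indices in self.groups:
            # shuffle within each group, then chunk into batches
            random.shuffle(indices)
            batches.extend(
                indices[i:i + self.batch_size]
                for i in range(0, len(indices), self.batch_size)
            )
        # shuffle the batch order so groups are interleaved
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(
            (len(indices) + self.batch_size - 1) // self.batch_size
            for indices in self.groups
        )

dataloader = DataLoader(datalist, batch_sampler=GroupedBatchSampler(datalist, 128))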