I have a map-style dataset, which is used for instance segmentation tasks. The dataset is very imbalanced, in the sense that some images contain only 10 objects while others contain up to 1200.
How can I limit the number of objects per batch?
A minimal reproducible example is:
import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
W = 700
H = 1000
def collate_fn(batch) -> tuple:
    return tuple(zip(*batch))


class SyntheticDataset(Dataset):
    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """
        Returns a single synthetic sample: an image, its target dict, and the image id.
        """
        # print("idx: ", idx)
        # deliberately left dangling
        # id = self.image_ids[idx].item()
        # image_id = self.image_ids[idx]
        image_id = torch.as_tensor(idx)
        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))
        target = {}
        target["image_id"] = image_id
        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(num_objects, dtype=torch.int64)  # one flag per object
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks
        return image, target, image_id
class BalancedObjectsSampler(BatchSampler):
    """Samples batches of at most batch_size images or num_objs_per_batch objects,
    whichever limit is reached first.

    Args:
        data_source (pd.Series): maps each image id to its object count.
        batch_size (int): maximum number of images per batch.
        num_objs_per_batch (int): maximum number of objects per batch.

    Yields:
        lists of image indices, one list per batch.
    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):
        obj_count = 0
        batch = []
        batches = []
        counter = 0
        for i, (k, s) in enumerate(self.data_source.items()):
            if (
                obj_count <= obj_count + s
                and len(batch) <= self.batch_size - 1
                and obj_count + s <= self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                # because of https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
                batch.append(i)
                obj_count += s
            else:
                batches.append(batch)
                yield batch
                obj_count = 0
                batch = []
                counter += 1
obj_sums = {}
batch_size = 10
workers = 4
fake_image_ids = np.random.randint(1600000, 1700000, 100)
# assign a random in-range object count to each image
for i, k in enumerate(fake_image_ids):
    obj_sums[k] = random.randint(10, 1200)
obj_counts = pd.Series(obj_sums)
train_dataset = SyntheticDataset(image_ids=fake_image_ids)
balanced_sampler = BalancedObjectsSampler(
    data_source=obj_counts,
    batch_size=batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)
data_loader_sampler = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=workers,
    collate_fn=collate_fn,
    sampler=balanced_sampler,
)
data_loader_iter = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    collate_fn=collate_fn,
)
Iterating over balanced_sampler
for i, bal_batch in enumerate(balanced_sampler):
    print(f"batch_{i}: ", bal_batch)
yields
batch_0: [0]
batch_1: [2, 3]
batch_2: [5]
batch_3: [7]
batch_4: [9, 10]
batch_5: [12, 13, 14, 15]
batch_6: [17, 18]
batch_7: [20, 21, 22]
batch_8: [24, 25]
batch_9: [27]
batch_10: [29]
batch_11: [31]
batch_12: [33]
batch_13: [35, 36, 37]
batch_14: [39, 40]
batch_15: [42, 43]
batch_16: [45, 46]
batch_17: [48, 49, 50]
batch_18: [52, 53, 54]
batch_19: [56]
batch_20: [58, 59]
batch_21: [61, 62]
batch_22: [64]
batch_23: [66]
batch_24: [68]
batch_25: [70, 71]
batch_26: [73]
batch_27: [75, 76, 77]
batch_28: [79, 80]
batch_29: [82, 83, 84, 85, 86, 87]
batch_30: [89]
batch_31: [91]
batch_32: [93, 94]
batch_33: [96]
batch_34: [98]
The values displayed above are the image indices within the dataset, but they could just as well be the batch indices or the image ids.
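As a sanity check, summing the object counts over each yielded batch (a quick sketch using the obj_counts Series defined above; 1500 is the num_objs_per_batch value I passed in) shows that no batch exceeds the limit:
for i, bal_batch in enumerate(balanced_sampler):
    # obj_counts is positionally aligned with the dataset, so iloc works on the yielded indices
    batch_total = obj_counts.iloc[bal_batch].sum()
    assert batch_total <= 1500  # holds by construction of the sampler
    print(f"batch_{i}: {len(bal_batch)} images, {batch_total} objects")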
By running
for i, batch in enumerate(data_loader_sampler):
    print("__sample__: ", i, len(batch[0]))
one sees that each batch contains a single sample instead of the expected number of images (see the note after this output):
__sample__: 0 1
__sample__: 1 1
__sample__: 2 1
__sample__: 3 1
__sample__: 4 1
__sample__: 5 1
__sample__: 6 1
__sample__: 7 1
__sample__: 8 1
__sample__: 9 1
__sample__: 10 1
__sample__: 11 1
__sample__: 12 1
__sample__: 13 1
__sample__: 14 1
__sample__: 15 1
__sample__: 16 1
__sample__: 17 1
__sample__: 18 1
__sample__: 19 1
__sample__: 20 1
__sample__: 21 1
__sample__: 22 1
__sample__: 23 1
__sample__: 24 1
__sample__: 25 1
__sample__: 26 1
__sample__: 27 1
__sample__: 28 1
__sample__: 29 1
__sample__: 30 1
__sample__: 31 1
__sample__: 32 1
__sample__: 33 1
__sample__: 34 1
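My understanding from the torch.utils.data documentation is that a sampler passed through the sampler= argument is expected to yield individual indices, which the DataLoader then groups on its own according to batch_size; lists of indices are only honored when the object is passed through batch_sampler= (in which case batch_size, shuffle, sampler and drop_last must be left at their defaults). A sketch of the wiring I assume would be needed, untested beyond this toy example:
# assumption: index lists from a custom sampler are only respected via batch_sampler=
data_loader_balanced = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=balanced_sampler,  # yields one list of image indices per batch
    num_workers=workers,
    collate_fn=collate_fn,
)
Whether this is the right way to combine such a sampler with a per-batch object limit is essentially what I am asking.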
What I am really trying to prevent is the following behavior that arises from
for i, batch in enumerate(data_loader_iter):
    print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))
which is
__iter__: 0 2510
__iter__: 1 2060
__iter__: 2 2203
__iter__: 3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
obj = _ForkingPickler.dumps(obj)
File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
if not self._poll(timeout):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.8/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "so.py", line 170, in <module>
for i, batch in enumerate(data_loader_iter):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly
which invariably happens when the number of objects per batch is greater than ~2500.
An immediate workaround would be to set batch_size to a low value, but I am looking for a more optimal solution.
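For completeness, two stopgaps I am aware of, neither of which addresses the imbalance itself (the sharing-strategy switch is an assumption on my part about what relieves the /dev/shm pressure): running with num_workers=0 while debugging so the real traceback surfaces in the main process, and sharing worker tensors through the filesystem instead of shared memory:
import torch.multiprocessing

# assumption: file_system sharing avoids the "unable to write to file </torch_...>"
# bus error, at the cost of slower inter-process tensor transfers
torch.multiprocessing.set_sharing_strategy("file_system")

# debugging only: no worker processes, so errors are raised directly in the main process
debug_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)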