Custom Sampler correct use in Pytorch
Asked Answered
W

1

7

I have a map-stype dataset, which is used for instance segmentation tasks. The dataset is very imbalanced, in the sense that some images have only 10 objects while others have up to 1200.

How can I limit the number of objects per batch?

A minimal reproducible example is:

import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


np.random.seed(0)
random.seed(0)
torch.manual_seed(0)


W = 700
H = 1000

def collate_fn(batch) -> tuple:
    return tuple(zip(*batch))

class SyntheticDataset(Dataset):
    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """
            returns single sample
        """
        # print("idx: ", idx)

        # deliberately left dangling
        # id = self.image_ids[idx].item()
        # image_id = self.image_ids[idx]
        image_id = torch.as_tensor(idx)
        image = torch.randint(0, 255, (H, W))

        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))

        target = {}
        target["image_id"] = image_id

        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(len(labels), dtype=torch.int64)

        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks

        return image, target, image_id


class BalancedObjectsSampler(BatchSampler):
    """Samples either batch_size images or batches num_objs_per_batch objects.

    Args:
        data_source (list): contains tuples of (img_id).
        batch_size (int): batch size.
        num_objs_per_batch (int): number of objects in a batch.
    Return
        yields the batch_ids/image_ids/image_indices

    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):

        obj_count = 0
        batch = []
        batches = []
        counter = 0
        for i, (k, s) in enumerate(self.data_source.iteritems()):
            if (
                obj_count <= obj_count + s
                and len(batch) <= self.batch_size - 1
                and obj_count + s <= self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                # because of https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
                batch.append(i)
                obj_count += s
            else:
                batches.append(batch)
                yield batch
                obj_count = 0
                batch = []
            counter += 1


obj_sums = {}
batch_size = 10
workers = 4
fake_image_ids = np.random.randint(1600000, 1700000, 100)

# assigning any in-range number objects count to each image
for i, k in enumerate(fake_image_ids):
    obj_sums[k] = random.randint(10, 1200)

obj_counts = pd.Series(obj_sums)

train_dataset = SyntheticDataset(image_ids=fake_image_ids)

balanced_sampler = BalancedObjectsSampler(
    data_source=obj_counts,
    batch_size=batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)

data_loader_sampler = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=workers,
    collate_fn=collate_fn,
    sampler=balanced_sampler,
)

data_loader_iter = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    collate_fn=collate_fn,
)

Iterating over the balanced_sampler

for i, bal_batch in enumerate(balanced_sampler):
    print(f"batch_{i}: ", bal_batch)

yields

batch_0:  [0]
batch_1:  [2, 3]
batch_2:  [5]
batch_3:  [7]
batch_4:  [9, 10]
batch_5:  [12, 13, 14, 15]
batch_6:  [17, 18]
batch_7:  [20, 21, 22]
batch_8:  [24, 25]
batch_9:  [27]
batch_10:  [29]
batch_11:  [31]
batch_12:  [33]
batch_13:  [35, 36, 37]
batch_14:  [39, 40]
batch_15:  [42, 43]
batch_16:  [45, 46]
batch_17:  [48, 49, 50]
batch_18:  [52, 53, 54]
batch_19:  [56]
batch_20:  [58, 59]
batch_21:  [61, 62]
batch_22:  [64]
batch_23:  [66]
batch_24:  [68]
batch_25:  [70, 71]
batch_26:  [73]
batch_27:  [75, 76, 77]
batch_28:  [79, 80]
batch_29:  [82, 83, 84, 85, 86, 87]
batch_30:  [89]
batch_31:  [91]
batch_32:  [93, 94]
batch_33:  [96]
batch_34:  [98]

The above displayed values are the images' indices, but could also be the batch index or even the images' ids.

By running

for i, batch in enumerate(data_loader_sampler):
    print("__sample__: ", i, len(batch[0]))

One sees that the batch contains a single sample instead of the expected amount.

__sample__:  0 1
__sample__:  1 1
__sample__:  2 1
__sample__:  3 1
__sample__:  4 1
__sample__:  5 1
__sample__:  6 1
__sample__:  7 1
__sample__:  8 1
__sample__:  9 1
__sample__:  10 1
__sample__:  11 1
__sample__:  12 1
__sample__:  13 1
__sample__:  14 1
__sample__:  15 1
__sample__:  16 1
__sample__:  17 1
__sample__:  18 1
__sample__:  19 1
__sample__:  20 1
__sample__:  21 1
__sample__:  22 1
__sample__:  23 1
__sample__:  24 1
__sample__:  25 1
__sample__:  26 1
__sample__:  27 1
__sample__:  28 1
__sample__:  29 1
__sample__:  30 1
__sample__:  31 1
__sample__:  32 1
__sample__:  33 1
__sample__:  34 1

What I am really trying to prevent is the following behavior that arises from

for i, batch in enumerate(data_loader_iter):
    print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))

which is

__iter__:  0 2510
__iter__:  1 2060
__iter__:  2 2203
__iter__:  3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
    fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "so.py", line 170, in <module>
    for i, batch in enumerate(data_loader_iter):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
    idx, data = self._get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
    success, data = self._try_get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly

which invariably happens when the number of objects per batch is greater than ~2500.

An immediate workaround would be to set the batch_size low, I just need a more optimal solution.

Warder answered 16/3, 2022 at 16:17 Comment(4)
Can you set workers=0 to get a better traceback?Pterodactyl
Yes, for debugging only though.Venlo
Please provide the stack trace with workers=0Pterodactyl
Iterating over data_loader_sampler is exactly the same.Venlo
U
5

If what you are trying to solve really is:

ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).

You could try resizing the allocated shared memory with

# mount -o remount,size=<whatever_is_enough>G /dev/shm

However, as this is not always possible, one fix to your problem would be

class SyntheticDataset(Dataset):

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, indices):
        worker_info = torch.utils.data.get_worker_info()

        batch = []
        for i in indices:
            sample = self.get_sample(i)
            batch.append(sample)
        gc.collect()
        return batch

    def get_sample(self, idx: int):

        image_id = torch.as_tensor(idx)
        image = torch.randint(0, 255, (H, W))

        num_objects = idx
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))

        target = {}
        target["image_id"] = image_id

        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(len(labels), dtype=torch.int64)

        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks

        return image, target, image_id

and

class BalancedObjectsSampler(BatchSampler):
    """Samples either batch_size images or batches num_objs_per_batch objects.

    Args:
        data_source (list): contains tuples of (img_id).
        batch_size (int): batch size.
        num_objs_per_batch (int): number of objects in a batch.
    Return
        yields the batch_ids/image_ids/image_indices

    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

        obj_count = 0
        batch = []
        batches = []
        batches_sums = []
        for i, (k, s) in enumerate(self.data_source.iteritems()):

            if (
                len(batch) < self.batch_size
                and obj_count + s < self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                batch.append(s)
                obj_count += s
            else:
                batches.append(len(batch))
                batches_sums.append(obj_count)
                obj_count = 0
                batch = []

        self.batches = batches
        self.batch_count = len(batches)

    def __iter__(self):
        batch = []
        img_counts_id = 0
        for idx, (k, s) in enumerate(self.data_source.iteritems()):
            if len(batch) < self.batches[img_counts_id] and idx < len(self.data_source):
                batch.append(s)
            elif len(batch) == self.batches[img_counts_id]:
                gc.collect()
                yield batch
                batch = []
                if img_counts_id < self.batch_count - 1:
                    img_counts_id += 1
                else:
                    break

        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.data_source) // self.batch_size
        else:
            return (len(self.data_source) + self.batch_size - 1) // self.batch_size

As SyntheticDataset's __getitem__ was receiving a list of indices, the simplest solution would just iterate over the indices and retrieve a list of samples. You may just have to collate the output differently in order to feed it to your model.

For the BalancedObjectsSampler, I calculated the size of each batch within the __init__ and used it in __iter__ to assemble the batches.

NOTE: This will still fail if your num_workers > 0 for you are trying to pack at most 1500 objects into a batch - and usually one worker loads one batch at a time. Hence, you have to re-assess your num_objs_per_batch when considering using multiprocessing.

Upbraid answered 17/3, 2022 at 19:22 Comment(1)
So true, thank you very much for taking the time. You have correctly addressed the issue.Venlo

© 2022 - 2024 — McMap. All rights reserved.