It seems like the serialization and deserialization associated with Python's multiprocessing limit the benefits of processing data in parallel.
In the following example, I create a custom iterable that returns a NumPy array. As the size of the array increases, data fetching becomes the bottleneck. This is expected. However, I would expect that increasing num_workers and prefetch_factor would reduce this bottleneck by preparing batches in advance, yet I do not see this behavior in the example below.
I test two cases where MyIterable returns
- a small object: np.random.random((10, 150))
- a large object: np.random.random((1000, 150))
The average time to process a batch in both scenarios is as follows:
# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253
For the small object, the time per batch drops as expected when num_workers is increased. But for the large object, it does not change much. I attribute this to the fact that the worker process has to serialize the NumPy array and the main process then has to deserialize it; the larger the object, the more time that takes. However, with a large enough num_workers and prefetch_factor, shouldn't the queue in the DataLoader always be filled, so that data fetching is not the bottleneck?
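To put a rough number on the serialization cost I suspect, here is a minimal side check of my own (only an approximation of what the worker queue actually moves, not part of the benchmark below) that times a pickle round trip of one batch worth of arrays for both sizes:

import pickle
import time
import numpy as np

def pickle_roundtrip(shape, batch_size=128):
    # one batch worth of arrays, roughly what a worker sends to the main process
    records = [np.random.random(shape) for _ in range(batch_size)]
    t0 = time.perf_counter()
    payload = pickle.dumps(records, protocol=pickle.HIGHEST_PROTOCOL)
    pickle.loads(payload)
    return time.perf_counter() - t0, len(payload)

for shape in [(10, 150), (1000, 150)]:
    seconds, num_bytes = pickle_roundtrip(shape)
    print(f'shape={shape}: roundtrip={seconds:.4f}s, payload={num_bytes / 1e6:.1f} MB')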
Moreover, changing prefetch_factor does not change anything. What is the point of prefetch_factor? The documentation says each worker pre-loads prefetch_factor batches (so num_workers * prefetch_factor batches in total), but as you can see this has no effect on the bottleneck.
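This is roughly the kind of sweep I used to check that (a sketch only: it assumes MyIterableDataset and collate_fn from the full script below, and the helper time_loader is hypothetical, mirroring the timing loop of get_performance_metrics):

import time
from time import sleep
from torch.utils.data import DataLoader

def time_loader(num_workers, prefetch_factor):
    # same timing loop as get_performance_metrics, with prefetch_factor exposed
    ds = MyIterableDataset(n=10000)
    dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=prefetch_factor,
                    persistent_workers=True, batch_size=128,
                    collate_fn=collate_fn, multiprocessing_context='spawn')
    warmup, times = 5, []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        if i >= warmup:
            times.append(time.perf_counter() - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    return sum(times) / len(times)

if __name__ == '__main__':
    for pf in [2, 4, 8, 16]:
        print(f'prefetch_factor={pf}: {time_loader(num_workers=4, prefetch_factor=pf):.3f} s/batch')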
I have added a more detailed step-by-step analysis in this question for reference.
import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)
    if num_workers == 0:
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                        batch_size=128, collate_fn=collate_fn,
                        multiprocessing_context='spawn')
    warmup = 5  # skip the first few batches while workers spin up
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)
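For reference, a rough back-of-envelope estimate of my own (assuming the default float64 dtype of np.random.random) of how much raw data each batch has to move from a worker to the main process:

# raw payload per batch of 128 float64 arrays
batch_size = 128
small_mb = 10 * 150 * 8 / 1e6      # MB per (10, 150) array
large_mb = 1000 * 150 * 8 / 1e6    # MB per (1000, 150) array
print(f'small batch: {batch_size * small_mb:.1f} MB')   # ~1.5 MB
print(f'large batch: {batch_size * large_mb:.1f} MB')   # ~153.6 MB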