Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[Feature] update sharded loader (#468)
Browse files Browse the repository at this point in the history
* update sharded loader

* fix

* fix threadpool

* use thread_pool, test no merge

* fix sharded batch
  • Loading branch information
zhreshold authored and szha committed Dec 21, 2018
1 parent 4b9cb10 commit f523396
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 64 deletions.
206 changes: 159 additions & 47 deletions src/gluonnlp/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,100 @@
"""DataLoader. An extension of Gluon data loader that allows multi-shard sampling."""
__all__ = ['ShardedDataLoader']

import sys
import io
import pickle
import multiprocessing
from multiprocessing.pool import ThreadPool
from mxnet import context
from mxnet.gluon.data.dataloader import DataLoader
from mxnet.gluon.data.dataloader import _MultiWorkerIter, _as_in_context
from mxnet.recordio import MXRecordIO
from mxnet.gluon.data.dataloader import ForkingPickler, _as_in_context
from mxnet.gluon.data.dataloader import default_mp_batchify_fn, default_batchify_fn
from mxnet.gluon.data import sampler as _sampler

# Per-process dataset handle; populated by `_worker_initializer` in pool workers.
_worker_dataset = None


def _worker_initializer(dataset):
    """Initializer for the processing pool.

    Stores `dataset` in the module-level `_worker_dataset` so that worker
    functions running in the same process can read it without receiving the
    dataset as an argument on every task.

    Parameters
    ----------
    dataset : object
        The dataset to expose to worker functions in this process.
    """
    # The global is per-process and only meaningful inside worker processes.
    # It is needed for MXIndexedRecordIO-backed datasets; any other dataset
    # could simply be passed along as a task argument.
    global _worker_dataset
    _worker_dataset = dataset

def _worker_fn(samples, batchify_fn, dataset=None):
    """Build one batch in a worker process and return it serialized.

    Parameters
    ----------
    samples : sequence
        Either a flat sequence of dataset indices, or a sequence of shards
        where every shard is a list/tuple of indices. Only the type of the
        first element is inspected to tell the two layouts apart.
    batchify_fn : callable
        Called with the list of fetched samples to produce a batch.
    dataset : object, optional
        Ignored; items are read from the module-level `_worker_dataset`.

    Returns
    -------
    bytes
        The batch (or list of per-shard batches) dumped with `ForkingPickler`.
    """
    # pylint: disable=unused-argument
    # Each worker process has to hold its own MXIndexedRecordIO handle, so the
    # dataset is kept as a per-process global instead of travelling with each task.
    def _collate(indices):
        return batchify_fn([_worker_dataset[idx] for idx in indices])

    is_sharded = isinstance(samples[0], (list, tuple))
    batch = [_collate(shard) for shard in samples] if is_sharded else _collate(samples)

    stream = io.BytesIO()
    pickler = ForkingPickler(stream, pickle.HIGHEST_PROTOCOL)
    pickler.dump(batch)
    return stream.getvalue()

def _thread_worker_fn(samples, batchify_fn, dataset):
    """Build one batch inside a thread-pool worker.

    Parameters
    ----------
    samples : sequence
        Either a flat sequence of dataset indices, or a sequence of shards
        where every shard is a list/tuple of indices. Only the type of the
        first element is inspected to tell the two layouts apart.
    batchify_fn : callable
        Called with the list of fetched samples to produce a batch.
    dataset : object
        Indexable dataset the samples are read from.

    Returns
    -------
    object
        A single batch for flat input, or a list with one batch per shard.
        The result is returned as-is (not serialized).
    """
    def _collate(indices):
        return batchify_fn([dataset[idx] for idx in indices])

    if not isinstance(samples[0], (list, tuple)):
        return _collate(samples)
    return [_collate(shard) for shard in samples]

class _MultiWorkerIter(object):
    """Internal multi-worker iterator for DataLoader.

    Work items drawn from `batch_sampler` are submitted to `worker_pool` with
    `apply_async`; the pending results are kept in a dict keyed by submission
    index and handed out in submission order.

    Parameters
    ----------
    worker_pool : pool object
        Must provide `apply_async(func, args)` returning an object with `get()`.
    batchify_fn : callable
        Forwarded to `worker_fn` to merge samples into a batch.
    batch_sampler : iterable with `__len__`
        Yields the sample indices (or shards of indices) for each batch.
    pin_memory : bool, default False
        If True, each batch is moved to `context.cpu_pinned()` before it is returned.
    worker_fn : callable, default `_worker_fn`
        Invoked in the pool as `worker_fn(samples, batchify_fn, dataset)`.
    prefetch : int, default 0
        Number of batches submitted to the pool up front.
    dataset : object, optional
        Forwarded to `worker_fn`. When it is None, worker results are treated
        as pickled bytes and unpickled before being returned.
    """
    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
                 worker_fn=_worker_fn, prefetch=0, dataset=None):
        self._worker_pool = worker_pool
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._data_buffer = {}  # submission index -> pending async result
        self._rcvd_idx = 0      # index of the next batch to hand out
        self._sent_idx = 0      # index the next submitted batch will get
        self._iter = iter(self._batch_sampler)
        self._worker_fn = worker_fn
        self._pin_memory = pin_memory
        self._dataset = dataset
        # pre-fetch
        for _ in range(prefetch):
            self._push_next()

    def __len__(self):
        """Return the number of batches reported by the batch sampler."""
        return len(self._batch_sampler)

    def __iter__(self):
        """Return self so the object can be consumed directly by a `for` loop."""
        return self

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            # sampler exhausted, nothing left to schedule
            return
        async_ret = self._worker_pool.apply_async(
            self._worker_fn, (r, self._batchify_fn, self._dataset))
        self._data_buffer[self._sent_idx] = async_ret
        self._sent_idx += 1

    def __next__(self):
        """Return the next batch in submission order.

        Raises
        ------
        StopIteration
            When every submitted batch has already been returned.
        """
        # keep the pipeline full: schedule one more batch per batch consumed
        self._push_next()
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            raise StopIteration

        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
        ret = self._data_buffer.pop(self._rcvd_idx)
        # Advance before waiting on the result: if the worker raised, `get()`
        # re-raises here and the counters must still match the buffer contents.
        self._rcvd_idx += 1
        result = ret.get()
        # Without a dataset the worker function returns pickled bytes. The payload
        # is produced by this loader's own worker pool, not by external input.
        batch = pickle.loads(result) if self._dataset is None else result
        if self._pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        return batch

    def next(self):
        """Python 2 style alias for `__next__`."""
        return self.__next__()

def _recursive_fork_recordio(obj, depth, max_depth=1000, _visited=None):
    """Recursively find instances of MXRecordIO and reset their file handles.

    This is required for MXRecordIO, which holds a C pointer to an opened file
    after fork.

    Parameters
    ----------
    obj : object
        Root of the object graph to walk. Only attributes stored in an
        object's `__dict__` are followed.
    depth : int
        Current recursion depth; callers start at 0.
    max_depth : int, default 1000
        Objects at this depth or deeper are not inspected.
    _visited : set of int, optional
        Internal bookkeeping: ids of objects already handled in this walk, so
        that cyclic references terminate and a record file reachable through
        several paths is reopened only once. Callers should leave it unset.
    """
    if depth >= max_depth:
        return
    if _visited is None:
        _visited = set()
    obj_id = id(obj)
    if obj_id in _visited:
        return
    _visited.add(obj_id)
    if isinstance(obj, MXRecordIO):
        obj.close()
        obj.open()  # re-obtain file handle in new process
    elif hasattr(obj, '__dict__'):
        for value in obj.__dict__.values():
            _recursive_fork_recordio(value, depth + 1, max_depth, _visited)


def worker_loop(dataset, key_queue, data_queue, batchify_fn):
    """Worker loop for multiprocessing DataLoader.

    Repeatedly takes `(idx, samples)` work items from `key_queue`, builds the
    batch and puts `(idx, batch)` on `data_queue`. The loop ends when an item
    whose `idx` is None is received.

    Parameters
    ----------
    dataset : object
        Indexable dataset the samples are read from.
    key_queue : queue-like
        Source of `(idx, samples)` work items; must provide `get()`.
    data_queue : queue-like
        Destination for `(idx, batch)` results; must provide `put()`.
    batchify_fn : callable
        Called with the list of fetched samples to produce a batch.
    """
    # Re-open record file handles in this process if applicable, keeping the
    # walk safely below the interpreter's recursion limit.
    recursion_limit = sys.getrecursionlimit()
    depth_cap = min(recursion_limit - 5, max(10, recursion_limit // 2))
    _recursive_fork_recordio(dataset, 0, depth_cap)

    def _collate(indices):
        return batchify_fn([dataset[i] for i in indices])

    idx, samples = key_queue.get()
    while idx is not None:
        # a list/tuple first element means `samples` is a sequence of shards
        if isinstance(samples[0], (list, tuple)):
            batch = [_collate(shard) for shard in samples]
        else:
            batch = _collate(samples)
        data_queue.put((idx, batch))
        idx, samples = key_queue.get()
# Iterator protocol hook: the object is its own iterator.
# NOTE(review): this method looks displaced from `_MultiWorkerIter` by the diff rendering - confirm.
def __iter__(self):
return self


class ShardedDataLoader(DataLoader):
class ShardedDataLoader(object):
"""Loads data from a dataset and returns mini-batches of data.
Parameters
Expand Down Expand Up @@ -102,18 +157,63 @@ def default_batchify_fn(data):
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
"""
prefetch : int, default is `num_workers * 2`
The number of batches to prefetch; only takes effect if `num_workers` > 0.
If `prefetch` > 0, worker processes are allowed to prefetch that many batches before
data is requested from the iterator.
Note that a larger prefetch count gives smoother bootstrapping performance,
but consumes more shared memory. A smaller count may defeat the purpose of using
multiple worker processes; try reducing `num_workers` in that case.
Defaults to `num_workers * 2`.
thread_pool : bool, default False
If ``True``, use a thread pool instead of a multiprocessing pool. Using a thread pool
avoids shared memory usage. If the `DataLoader` is mostly I/O bound, or the GIL is not a
bottleneck, the thread pool version may perform better than multiprocessing.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False):
super(ShardedDataLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle,
sampler=sampler, last_batch=last_batch,
batch_sampler=batch_sampler,
batchify_fn=batchify_fn,
num_workers=num_workers,
pin_memory=pin_memory)

num_workers=0, pin_memory=False, prefetch=None, thread_pool=False):
self._dataset = dataset
self._pin_memory = pin_memory
self._thread_pool = thread_pool

if batch_sampler is None:
if batch_size is None:
raise ValueError("batch_size must be specified unless " \
"batch_sampler is specified")
if sampler is None:
if shuffle:
sampler = _sampler.RandomSampler(len(dataset))
else:
sampler = _sampler.SequentialSampler(len(dataset))
elif shuffle:
raise ValueError("shuffle must not be specified if sampler is specified")

batch_sampler = _sampler.BatchSampler(
sampler, batch_size, last_batch if last_batch else 'keep')
elif batch_size is not None or shuffle or sampler is not None or \
last_batch is not None:
raise ValueError("batch_size, shuffle, sampler and last_batch must " \
"not be specified if batch_sampler is specified.")

self._batch_sampler = batch_sampler
self._num_workers = num_workers if num_workers >= 0 else 0
self._worker_pool = None
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if self._num_workers > 0:
if self._thread_pool:
self._worker_pool = ThreadPool(self._num_workers)
else:
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
else:
self._batchify_fn = default_batchify_fn
else:
self._batchify_fn = batchify_fn

def __iter__(self):
if self._num_workers == 0:
Expand All @@ -133,6 +233,18 @@ def _same_process_iter():
return _same_process_iter()

# multi-worker
return _MultiWorkerIter(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler,
self._pin_memory, worker_loop)
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None)

def __len__(self):
return len(self._batch_sampler)

def __del__(self):
if self._worker_pool:
# manually terminate due to a bug that pool is not automatically terminated
# https://bugs.python.org/issue34172
assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
self._worker_pool.terminate()
38 changes: 21 additions & 17 deletions tests/unittest/train/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,30 @@ def test_sharded_data_loader():
num_buckets=1,
shuffle=False,
num_shards=num_shards)
for num_workers in [0, 1, 2, 3, 4]:
loader = ShardedDataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers)
for i, seqs in enumerate(loader):
assert len(seqs) == num_shards
for j in range(num_shards):
if i != len(loader) - 1:
assert mx.test_utils.almost_equal(seqs[j][0].asnumpy(),
X[(i*num_shards+j)*2:(i*num_shards+j+1)*2])
assert mx.test_utils.almost_equal(seqs[j][1].asnumpy(),
Y[(i*num_shards+j)*2:(i*num_shards+j+1)*2])
else:
assert mx.test_utils.almost_equal(seqs[j][0].asnumpy(),
X[(i*num_shards+j)*2-num_shards:
(i*num_shards+j+1)*2-num_shards])
assert mx.test_utils.almost_equal(seqs[j][1].asnumpy(),
Y[(i*num_shards+j)*2-num_shards:
(i*num_shards+j+1)*2-num_shards])
for thread_pool in [True, False]:
for num_workers in [0, 1, 2, 3, 4]:
loader = ShardedDataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, thread_pool=thread_pool)
for i, seqs in enumerate(loader):
assert len(seqs) == num_shards
for j in range(num_shards):
if i != len(loader) - 1:
assert mx.test_utils.almost_equal(seqs[j][0].asnumpy(),
X[(i*num_shards+j)*2:(i*num_shards+j+1)*2])
assert mx.test_utils.almost_equal(seqs[j][1].asnumpy(),
Y[(i*num_shards+j)*2:(i*num_shards+j+1)*2])
else:
assert mx.test_utils.almost_equal(seqs[j][0].asnumpy(),
X[(i*num_shards+j)*2-num_shards:
(i*num_shards+j+1)*2-num_shards])
assert mx.test_utils.almost_equal(seqs[j][1].asnumpy(),
Y[(i*num_shards+j)*2-num_shards:
(i*num_shards+j+1)*2-num_shards])

@pytest.mark.remote_required
def test_sharded_data_loader_record_file():
if not hasattr(mx.recordio.MXRecordIO, '_check_pid'):
# skip if mxnet<=1.4.0 detected, some hotfix is not included so recordfile will break
return
# test record file
url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}'
filename = 'val.rec'
Expand Down

0 comments on commit f523396

Please sign in to comment.