Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Rewrite dataloader, improves responsiveness and reliability #13447

Merged
merged 8 commits into from
Nov 30, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 189 additions & 31 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

from . import sampler as _sampler
from ... import nd, context
from ...recordio import MXRecordIO

if sys.platform == 'darwin' or sys.platform == 'win32':
def rebuild_ndarray(*args):
Expand Down Expand Up @@ -159,37 +158,17 @@ def _as_in_context(data, ctx):
return [_as_in_context(d, ctx) for d in data]
return data

def _recursive_fork_recordio(obj, depth, max_depth=1000):
"""Recursively find instance of MXRecordIO and reset file handler.
This is required for MXRecordIO which holds a C pointer to a opened file after fork.
"""
if depth >= max_depth:
return
if isinstance(obj, MXRecordIO):
obj.close()
obj.open() # re-obtain file hanlder in new process
elif (hasattr(obj, '__dict__')):
for _, v in obj.__dict__.items():
_recursive_fork_recordio(v, depth + 1, max_depth)

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
# re-fork a new recordio handler in new process if applicable
# for a dataset with transform function, the depth of MXRecordIO is 1
# for a lazy transformer, the depth is 2
# for a user defined transformer, the depth is unknown, try a reasonable depth
limit = sys.getrecursionlimit()
max_recursion_depth = min(limit - 5, max(10, limit // 2))
_recursive_fork_recordio(dataset, 0, max_recursion_depth)

def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
while True:
idx, samples = key_queue.get()
if idx is None:
break
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))

def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
Expand All @@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non
data_buffer[idx] = batch


class _MultiWorkerIter(object):
"""Interal multi-worker iterator for DataLoader."""
class _MultiWorkerIterV1(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=worker_loop):
worker_fn=worker_loop_v1):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
Expand Down Expand Up @@ -237,7 +216,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=
self._workers = workers

self._fetcher = threading.Thread(
target=fetcher_loop,
target=fetcher_loop_v1,
args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
self._fetcher.daemon = True
self._fetcher.start()
Expand Down Expand Up @@ -299,7 +278,7 @@ def shutdown(self):
self._shutdown = True


class DataLoader(object):
class DataLoaderV1(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the new DataLoader preserves the API, so why not remove DataLoaderV1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If some specific implementations were relying on the old queue based methods, I think leaving the older version is still preferable for some user, what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion about it. If you deem the reliance on queue as reliance on an undocumented implementation detail of V1, then it may be better to remove from Gluon and copy the V1 code over to the specific implementation? It may be reduce the number of APIs to maintain in the future. But I'm also OK with keeping it here if you prefer

"""Loads data from a dataset and returns mini-batches of data.

Parameters
Expand Down Expand Up @@ -390,8 +369,187 @@ def same_process_iter():
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)
return _MultiWorkerIterV1(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)

def __len__(self):
return len(self._batch_sampler)

_worker_dataset = None
def _worker_fn(samples, batchify_fn):
"""Function for processing data in worker process."""
# it is required that each worker process has to fork a new MXIndexedRecordIO handle
# preserving dataset as global variable can save tons of overhead and is safe in new process
global _worker_dataset
batch = batchify_fn([_worker_dataset[i] for i in samples])
batch = [batch] if not isinstance(batch, (list, tuple)) else batch
ret = [reduce_ndarray(x)[1] for x in batch] # reduce_ndarray(x)[0] is the rebuild function
return ret

class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=_worker_fn, prefetch=0):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
self._data_buffer = {}
self._rcvd_idx = 0
self._sent_idx = 0
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
# pre-fetch
for _ in range(prefetch):
self._push_next()

def __len__(self):
return len(self._batch_sampler)

def _push_next(self):
"""Assign next batch workload to workers."""
r = next(self._iter, None)
if r is None:
return
async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1

def __next__(self):
self._push_next()
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, "Data buffer should be empty at this moment"
raise StopIteration

assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
shared_batch = ret.get()
batch = tuple([rebuild_ndarray(*x) for x in shared_batch])
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch

def next(self):
return self.__next__()

def __iter__(self):
return self


class DataLoader(object):
"""Loads data from a dataset and returns mini-batches of data.

Parameters
----------
dataset : Dataset
Source dataset. Note that numpy and mxnet arrays can be directly used
as a Dataset.
batch_size : int
Size of mini-batch.
shuffle : bool
Whether to shuffle the samples.
sampler : Sampler
The sampler to use. Either specify sampler or shuffle, not both.
last_batch : {'keep', 'discard', 'rollover'}
How to handle the last batch if batch_size does not evenly divide
`len(dataset)`.

keep - A batch with less samples than previous batches is returned.
discard - The last batch is discarded if its incomplete.
rollover - The remaining samples are rolled over to the next epoch.
batch_sampler : Sampler
A sampler that returns mini-batches. Do not specify batch_size,
shuffle, sampler, and last_batch if batch_sampler is specified.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to `default_batchify_fn`::

def default_batchify_fn(data):
if isinstance(data[0], nd.NDArray):
return nd.stack(*data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
return nd.array(data, dtype=data.dtype)

num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
pin_memory : boolean, default False
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
prefetch : int, default is `num_workers * 2`
The number of prefetching batches only works if `num_workers` > 0.
If `prefetch` > 0, it allow worker process to prefetch certain batches before
acquiring data from iterators.
Note that using large prefetching batch will provide smoother bootstrapping performance,
but will consume more shared_memory. Using smaller number may forfeit the purpose of using
multiple worker processes, try reduce `num_workers` in this case.
By default it defaults to `num_workers * 2`.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, prefetch=None):
self._dataset = dataset
self._pin_memory = pin_memory

if batch_sampler is None:
if batch_size is None:
raise ValueError("batch_size must be specified unless " \
"batch_sampler is specified")
if sampler is None:
if shuffle:
sampler = _sampler.RandomSampler(len(dataset))
else:
sampler = _sampler.SequentialSampler(len(dataset))
elif shuffle:
raise ValueError("shuffle must not be specified if sampler is specified")

batch_sampler = _sampler.BatchSampler(
sampler, batch_size, last_batch if last_batch else 'keep')
elif batch_size is not None or shuffle or sampler is not None or \
last_batch is not None:
raise ValueError("batch_size, shuffle, sampler and last_batch must " \
"not be specified if batch_sampler is specified.")

self._batch_sampler = batch_sampler
self._num_workers = num_workers if num_workers >= 0 else 0
self._worker_pool = None
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if self._num_workers > 0:
def worker_initializer(data):
global _worker_dataset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this break when using multiple DataLoader with different datasets?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is one global dataset per worker process, different data loaders have independent process pools

_worker_dataset = data

self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=worker_initializer, initargs=[self._dataset])
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
else:
self._batchify_fn = default_batchify_fn
else:
self._batchify_fn = batchify_fn

def __iter__(self):
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, worker_fn=_worker_fn,
prefetch=self._prefetch)

def __len__(self):
return len(self._batch_sampler)
10 changes: 10 additions & 0 deletions python/mxnet/recordio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Read and write for the RecordIO data format."""
from __future__ import absolute_import
from collections import namedtuple
from multiprocessing import current_process

import ctypes
import struct
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, uri, flag):
self.uri = c_str(uri)
self.handle = RecordIOHandle()
self.flag = flag
self.pid = None
self.is_open = False
self.open()

Expand All @@ -78,6 +80,7 @@ def open(self):
self.writable = False
else:
raise ValueError("Invalid flag %s"%self.flag)
self.pid = current_process().pid
self.is_open = True

def __del__(self):
Expand Down Expand Up @@ -118,6 +121,7 @@ def close(self):
else:
check_call(_LIB.MXRecordIOReaderFree(self.handle))
self.is_open = False
self.pid = None

def reset(self):
"""Resets the pointer to first item.
Expand Down Expand Up @@ -156,6 +160,8 @@ def write(self, buf):
Buffer to write.
"""
assert self.writable
assert self.pid == current_process().pid, \
"writing in different process is forbidden"
check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
ctypes.c_char_p(buf),
ctypes.c_size_t(len(buf))))
Expand All @@ -182,6 +188,10 @@ def read(self):
Buffer read.
"""
assert not self.writable
if not self.pid == current_process().pid:
# in forked process, obtain a new handle
# print("PID not matching, reset")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused code

self.reset()
buf = ctypes.c_char_p()
size = ctypes.c_size_t()
check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
Expand Down