Rewrite dataloader, improves responsiveness and reliability #13447
Merged
Changes shown are from 5 of the pull request's 8 commits.

Commits (all by zhreshold):
54c0d0c  fix recordio.py
965dc2e  rewrite dataloader with pool
43e315c  fix batch as tuple
9ed47a3  fix prefetching
a400375  fix pylint
e5b11d4  picklable function
7de1f2e  use pickle
1b2fc73  add missing commit
File: python/mxnet/gluon/data/dataloader.py
@@ -36,7 +36,6 @@
 from . import sampler as _sampler
 from ... import nd, context
-from ...recordio import MXRecordIO
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -159,37 +158,17 @@ def _as_in_context(data, ctx):
         return [_as_in_context(d, ctx) for d in data]
     return data
 
-def _recursive_fork_recordio(obj, depth, max_depth=1000):
-    """Recursively find instance of MXRecordIO and reset file handler.
-    This is required for MXRecordIO which holds a C pointer to a opened file after fork.
-    """
-    if depth >= max_depth:
-        return
-    if isinstance(obj, MXRecordIO):
-        obj.close()
-        obj.open()  # re-obtain file hanlder in new process
-    elif (hasattr(obj, '__dict__')):
-        for _, v in obj.__dict__.items():
-            _recursive_fork_recordio(v, depth + 1, max_depth)
-
-def worker_loop(dataset, key_queue, data_queue, batchify_fn):
-    """Worker loop for multiprocessing DataLoader."""
-    # re-fork a new recordio handler in new process if applicable
-    # for a dataset with transform function, the depth of MXRecordIO is 1
-    # for a lazy transformer, the depth is 2
-    # for a user defined transformer, the depth is unknown, try a reasonable depth
-    limit = sys.getrecursionlimit()
-    max_recursion_depth = min(limit - 5, max(10, limit // 2))
-    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
+def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
+    """Worker loop for multiprocessing DataLoader."""
     while True:
         idx, samples = key_queue.get()
         if idx is None:
             break
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
-def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
+def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
     """Fetcher loop for fetching data from queue and put in reorder dict."""
     while True:
         idx, batch = data_queue.get()
@@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
             data_buffer[idx] = batch
 
 
-class _MultiWorkerIter(object):
-    """Interal multi-worker iterator for DataLoader."""
+class _MultiWorkerIterV1(object):
+    """Internal multi-worker iterator for DataLoader."""
     def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
-                 worker_fn=worker_loop):
+                 worker_fn=worker_loop_v1):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
@@ -237,7 +216,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
         self._workers = workers
 
         self._fetcher = threading.Thread(
-            target=fetcher_loop,
+            target=fetcher_loop_v1,
             args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
         self._fetcher.daemon = True
         self._fetcher.start()
@@ -299,7 +278,7 @@ def shutdown(self):
             self._shutdown = True
 
 
-class DataLoader(object):
+class DataLoaderV1(object):
     """Loads data from a dataset and returns mini-batches of data.
 
     Parameters
@@ -390,8 +369,187 @@ def same_process_iter():
             return same_process_iter()
 
         # multi-worker
-        return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler, self._pin_memory)
+        return _MultiWorkerIterV1(self._num_workers, self._dataset,
+                                  self._batchify_fn, self._batch_sampler, self._pin_memory)
 
     def __len__(self):
         return len(self._batch_sampler)
+
+
+_worker_dataset = None
+def _worker_fn(samples, batchify_fn):
+    """Function for processing data in worker process."""
+    # each worker process is required to fork its own MXIndexedRecordIO handle;
+    # keeping the dataset as a global variable saves a lot of overhead and is safe in the new process
+    global _worker_dataset
+    batch = batchify_fn([_worker_dataset[i] for i in samples])
+    batch = [batch] if not isinstance(batch, (list, tuple)) else batch
+    ret = [reduce_ndarray(x)[1] for x in batch]  # reduce_ndarray(x)[0] is the rebuild function
+    return ret
+
+
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
+                 worker_fn=_worker_fn, prefetch=0):
+        self._worker_pool = worker_pool
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._data_buffer = {}
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._worker_fn = worker_fn
+        self._pin_memory = pin_memory
+        # pre-fetch
+        for _ in range(prefetch):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)
+        if r is None:
+            return
+        async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
+        self._data_buffer[self._sent_idx] = async_ret
+        self._sent_idx += 1
+
+    def __next__(self):
+        self._push_next()
+        if self._rcvd_idx == self._sent_idx:
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            raise StopIteration
+
+        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
+        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
+        ret = self._data_buffer.pop(self._rcvd_idx)
+        shared_batch = ret.get()
+        batch = tuple([rebuild_ndarray(*x) for x in shared_batch])
+        if self._pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        batch = batch[0] if len(batch) == 1 else batch
+        self._rcvd_idx += 1
+        return batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoader(object):
+    """Loads data from a dataset and returns mini-batches of data.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batch_size : int
+        Size of mini-batch.
+    shuffle : bool
+        Whether to shuffle the samples.
+    sampler : Sampler
+        The sampler to use. Either specify sampler or shuffle, not both.
+    last_batch : {'keep', 'discard', 'rollover'}
+        How to handle the last batch if batch_size does not evenly divide
+        `len(dataset)`.
+
+        keep - A batch with fewer samples than previous batches is returned.
+        discard - The last batch is discarded if it is incomplete.
+        rollover - The remaining samples are rolled over to the next epoch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    batchify_fn : callable
+        Callback function that allows users to specify how to merge samples
+        into a batch. Defaults to `default_batchify_fn`::
+
+            def default_batchify_fn(data):
+                if isinstance(data[0], nd.NDArray):
+                    return nd.stack(*data)
+                elif isinstance(data[0], tuple):
+                    data = zip(*data)
+                    return [default_batchify_fn(i) for i in data]
+                else:
+                    data = np.asarray(data)
+                    return nd.array(data, dtype=data.dtype)
+
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+    prefetch : int, default is `num_workers * 2`
+        The number of batches to prefetch; only takes effect if `num_workers` > 0.
+        If `prefetch` > 0, worker processes start preparing that many batches ahead of
+        the batches already requested from the iterator.
+        Note that a larger prefetch count gives smoother start-up performance but consumes
+        more shared memory. Too small a value may defeat the purpose of using multiple
+        worker processes; in that case, try reducing `num_workers` instead.
+        Defaults to `num_workers * 2`.
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None, batchify_fn=None,
+                 num_workers=0, pin_memory=False, prefetch=None):
+        self._dataset = dataset
+        self._pin_memory = pin_memory
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(
+                sampler, batch_size, last_batch if last_batch else 'keep')
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+        self._num_workers = num_workers if num_workers >= 0 else 0
+        self._worker_pool = None
+        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
+        if self._num_workers > 0:
+            def worker_initializer(data):
+                global _worker_dataset
+                _worker_dataset = data
+
+            self._worker_pool = multiprocessing.Pool(
+                self._num_workers, initializer=worker_initializer, initargs=[self._dataset])
+        if batchify_fn is None:
+            if num_workers > 0:
+                self._batchify_fn = default_mp_batchify_fn
+            else:
+                self._batchify_fn = default_batchify_fn
+        else:
+            self._batchify_fn = batchify_fn
+
+    def __iter__(self):
+        if self._num_workers == 0:
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()
+
+        # multi-worker
+        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
+                                pin_memory=self._pin_memory, worker_fn=_worker_fn,
+                                prefetch=self._prefetch)
+
+    def __len__(self):
+        return len(self._batch_sampler)

Inline review thread on the `global _worker_dataset` line in `worker_initializer`:
Comment: Won't this break when using multiple DataLoader with different datasets?
Reply: No, there is one global dataset per worker process; different data loaders have independent process pools.
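The core change in this rewrite is how work is distributed: the dataset is handed to every worker exactly once through the `multiprocessing.Pool` initializer and kept as a per-process global, so each batch request only ships a list of sample indices to a worker and receives the assembled batch back via `apply_async`. The sketch below illustrates that pattern in isolation; it is not MXNet code, and the names `_dataset`, `_init_worker`, and `make_batch` are invented for the example.

```python
import multiprocessing

_dataset = None  # one copy per worker process, set once by the pool initializer


def _init_worker(dataset):
    """Runs once in each worker; avoids re-sending the dataset with every batch."""
    global _dataset
    _dataset = dataset


def make_batch(indices):
    """Hypothetical stand-in for batchify_fn: gather samples by index."""
    return [_dataset[i] for i in indices]


if __name__ == '__main__':
    data = list(range(100))
    batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    with multiprocessing.Pool(2, initializer=_init_worker, initargs=(data,)) as pool:
        # mirrors _MultiWorkerIter._push_next / __next__: submit work asynchronously,
        # then collect the results in submission order
        results = [pool.apply_async(make_batch, (b,)) for b in batches]
        for r in results:
            print(r.get())
```

In the actual diff the returned batches additionally go through `reduce_ndarray`/`rebuild_ndarray`, so NDArray payloads travel through shared memory rather than being pickled wholesale, and `prefetch` controls how many of these `apply_async` calls are kept in flight at once.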
File: python/mxnet/recordio.py
@@ -18,6 +18,7 @@
 """Read and write for the RecordIO data format."""
 from __future__ import absolute_import
 from collections import namedtuple
+from multiprocessing import current_process
 
 import ctypes
 import struct
@@ -65,6 +66,7 @@ def __init__(self, uri, flag):
         self.uri = c_str(uri)
         self.handle = RecordIOHandle()
         self.flag = flag
+        self.pid = None
         self.is_open = False
         self.open()
 
@@ -78,6 +80,7 @@ def open(self):
             self.writable = False
         else:
             raise ValueError("Invalid flag %s"%self.flag)
+        self.pid = current_process().pid
         self.is_open = True
 
     def __del__(self):
@@ -118,6 +121,7 @@ def close(self):
         else:
             check_call(_LIB.MXRecordIOReaderFree(self.handle))
         self.is_open = False
+        self.pid = None
 
     def reset(self):
         """Resets the pointer to first item.
@@ -156,6 +160,8 @@ def write(self, buf):
             Buffer to write.
         """
         assert self.writable
+        assert self.pid == current_process().pid, \
+            "writing in different process is forbidden"
         check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                     ctypes.c_char_p(buf),
                                                     ctypes.c_size_t(len(buf))))
@@ -182,6 +188,10 @@ def read(self):
             Buffer read.
         """
         assert not self.writable
+        if not self.pid == current_process().pid:
+            # in forked process, obtain a new handle
+            # print("PID not matching, reset")
+            self.reset()
         buf = ctypes.c_char_p()
         size = ctypes.c_size_t()
         check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,

Inline review comment on the commented-out `print` line:
Comment: remove unused code
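The `pid` bookkeeping added to `MXRecordIO` follows a common fork-safety pattern: record which process opened the underlying C handle, lazily reopen it when a forked worker reads, and refuse outright when a forked worker writes. Below is a minimal sketch of the same pattern, not MXNet code; `ForkAwareReader` is an invented name and an ordinary Python file stands in for the C handle.

```python
from multiprocessing import current_process


class ForkAwareReader(object):
    """Reopen the underlying handle when used from a process other than the opener."""

    def __init__(self, path):
        self.path = path
        self.pid = None
        self._file = None
        self.open()

    def open(self):
        self._file = open(self.path, 'rb')
        self.pid = current_process().pid  # remember which process owns the handle

    def read(self, size=16):
        if self.pid != current_process().pid:
            # we are in a forked child: the inherited handle is not ours, so reopen it
            self.open()
        return self._file.read(size)
```

For writes the diff takes the stricter option and simply asserts that the PID matches, since letting forked processes write through the same handle could interleave and corrupt records.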
Review conversation:

Comment: It seems the new `DataLoader` preserves the API, so why not remove `DataLoaderV1`?

Reply: If some specific implementations were relying on the old queue-based methods, I think leaving the older version in place is still preferable for some users. What do you think?

Comment: I don't have a strong opinion about it. If you deem the reliance on queues as reliance on an undocumented implementation detail of V1, then it may be better to remove it from Gluon and copy the V1 code over to the specific implementation. It may reduce the number of APIs to maintain in the future. But I'm also OK with keeping it here if you prefer.