Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add pin_device_id option to Gluon DataLoader #14136

Merged
merged 3 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,15 @@ def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))

def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False,
pin_device_id=0, data_buffer_lock=None):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
if idx is None:
break
if pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = _as_in_context(batch, context.cpu_pinned(pin_device_id))
else:
batch = _as_in_context(batch, context.cpu())
if data_buffer_lock is not None:
Expand All @@ -188,8 +189,8 @@ def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=

class _MultiWorkerIterV1(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=worker_loop_v1):
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler,
pin_memory=False, pin_device_id=0, worker_fn=worker_loop_v1):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
Expand Down Expand Up @@ -218,7 +219,8 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=

self._fetcher = threading.Thread(
target=fetcher_loop_v1,
args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
args=(self._data_queue, self._data_buffer, pin_memory,
pin_device_id, self._data_buffer_lock))
self._fetcher.daemon = True
self._fetcher.start()

Expand Down Expand Up @@ -323,12 +325,15 @@ def default_batchify_fn(data):
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False):
num_workers=0, pin_memory=False, pin_device_id=0):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id

if batch_sampler is None:
if batch_size is None:
Expand Down Expand Up @@ -365,13 +370,14 @@ def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIterV1(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)
self._batchify_fn, self._batch_sampler,
self._pin_memory, self._pin_device_id)

def __len__(self):
return len(self._batch_sampler)
Expand Down Expand Up @@ -403,7 +409,7 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=_worker_fn, prefetch=0, dataset=None):
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
Expand All @@ -413,6 +419,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._dataset = dataset
# pre-fetch
for _ in range(prefetch):
Expand Down Expand Up @@ -442,7 +449,7 @@ def __next__(self):
ret = self._data_buffer.pop(self._rcvd_idx)
batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch
Expand Down Expand Up @@ -498,6 +505,8 @@ def default_batchify_fn(data):
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
prefetch : int, default is `num_workers * 2`
The number of batches to prefetch; only takes effect if `num_workers` > 0.
If `prefetch` > 0, it allows worker processes to prefetch certain batches before
Expand All @@ -514,9 +523,11 @@ def default_batchify_fn(data):
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, prefetch=None, thread_pool=False):
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=None, thread_pool=False):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._thread_pool = thread_pool

if batch_sampler is None:
Expand Down Expand Up @@ -562,13 +573,13 @@ def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory,
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None)
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ def test_multi_worker_dataloader_release_pool():
del the_iter
del D


def test_dataloader_context():
    """Check which context the batches produced by DataLoader live on.

    Three loader configurations are exercised over the same dataset with a
    batch size of 8:

    * no pinning                      -> ``context.cpu(0)``
    * ``pin_memory=True``             -> ``context.cpu_pinned(0)``
    * ``pin_memory=True`` with a custom ``pin_device_id``
                                      -> ``context.cpu_pinned(<that id>)``
    """
    data = np.random.uniform(size=(10, 20))
    dataset = gluon.data.ArrayDataset(data)
    default_dev_id = 0
    # NOTE(review): presumably a pinned context on device 1 needs a second GPU
    # on CUDA builds -- confirm this test is safe on single-GPU hosts.
    custom_dev_id = 1

    # Each case pairs the extra DataLoader keyword arguments with the context
    # every yielded batch is expected to report.
    cases = (
        # use non-pinned memory
        ({}, context.cpu(default_dev_id)),
        # use pinned memory with default device id
        ({'pin_memory': True}, context.cpu_pinned(default_dev_id)),
        # use pinned memory with custom device id
        ({'pin_memory': True, 'pin_device_id': custom_dev_id},
         context.cpu_pinned(custom_dev_id)),
    )

    for loader_kwargs, expected_ctx in cases:
        loader = gluon.data.DataLoader(dataset, 8, **loader_kwargs)
        for batch in loader:
            assert batch.context == expected_ctx


if __name__ == '__main__':
    # Run this module's tests through nose when the file is executed as a script.
    from nose import runmodule
    runmodule()