Skip to content

Commit

Permalink
fix recordfile dataset with multi worker (apache#11370)
Browse files Browse the repository at this point in the history
* fix recordfile dataset with multi worker

* fix another test

* fix
  • Loading branch information
zhreshold authored and szha committed Jun 25, 2018
1 parent 1c7e9b6 commit 4713045
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def default_mp_batchify_fn(data):

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
dataset._fork()
while True:
idx, samples = key_queue.get()
if idx is None:
Expand Down
13 changes: 11 additions & 2 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def base_fn(x, *args):
return fn(x)
return self.transform(base_fn, lazy)

def _fork(self):
"""Protective operations required when launching multiprocess workers."""
# for non file descriptor related datasets, just skip
pass


class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
Expand Down Expand Up @@ -173,8 +178,12 @@ class RecordFileDataset(Dataset):
Path to rec file.
"""
def __init__(self, filename):
idx_file = os.path.splitext(filename)[0] + '.idx'
self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r')
self.idx_file = os.path.splitext(filename)[0] + '.idx'
self.filename = filename
self._fork()

def _fork(self):
self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

def __getitem__(self, idx):
return self._record.read_idx(self._record.keys[idx])
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def test_recordimage_dataset():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i

@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
# This test is pointless on Windows because Windows doesn't fork
if platform.system() != 'Windows':
recfile = prepare_record()
dataset = gluon.data.vision.ImageRecordDataset(recfile)
loader = gluon.data.DataLoader(dataset, 1, num_workers=5)

for i, (x, y) in enumerate(loader):
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i

@with_seed()
def test_sampler():
seq_sampler = gluon.data.SequentialSampler(10)
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def test_ImageRecordIter_seed_augmentation():
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3)
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)

Expand Down

0 comments on commit 4713045

Please sign in to comment.