diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 29b9b81aca04..eb1eb419cd02 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -151,6 +151,7 @@ def default_mp_batchify_fn(data): def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" + dataset._fork() while True: idx, samples = key_queue.get() if idx is None: diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index bf5fa0a6d1e1..13e2b57a8c59 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -94,6 +94,11 @@ def base_fn(x, *args): return fn(x) return self.transform(base_fn, lazy) + def _fork(self): + """Protective operations required when launching multiprocess workers.""" + # for non file descriptor related datasets, just skip + pass + class SimpleDataset(Dataset): """Simple Dataset wrapper for lists and arrays. @@ -173,8 +178,12 @@ class RecordFileDataset(Dataset): Path to rec file. """ def __init__(self, filename): - idx_file = os.path.splitext(filename)[0] + '.idx' - self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r') + self.idx_file = os.path.splitext(filename)[0] + '.idx' + self.filename = filename + self._fork() + + def _fork(self): + self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r') def __getitem__(self, idx): return self._record.read_idx(self._record.keys[idx]) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index ef2ba2ab9b28..043804487b5e 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -72,6 +72,18 @@ def test_recordimage_dataset(): assert x.shape[0] == 1 and x.shape[3] == 3 assert y.asscalar() == i +@with_seed() +def test_recordimage_dataset_with_data_loader_multiworker(): + # This test is pointless on Windows because Windows doesn't fork + if platform.system() != 'Windows': + recfile = prepare_record() + dataset = gluon.data.vision.ImageRecordDataset(recfile) + loader = gluon.data.DataLoader(dataset, 1, num_workers=5) + + for i, (x, y) in enumerate(loader): + assert x.shape[0] == 1 and x.shape[3] == 3 + assert y.asscalar() == i + @with_seed() def test_sampler(): seq_sampler = gluon.data.SequentialSampler(10) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index f0928a6b61a7..dbd327d429f7 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -407,7 +407,8 @@ def test_ImageRecordIter_seed_augmentation(): mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), - batch_size=3) + batch_size=3, + seed_aug=seed_aug) batch = dataiter.next() data = batch.data[0].asnumpy().astype(np.uint8)