From b7f0122c8dd741ca71803ff81bd3ebd7a4e1070e Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 7 Dec 2018 17:06:05 -0800 Subject: [PATCH] fix the situation where idx didn't align with rec (#13550) minor fix the image.py add last_batch_handle for imagedeiter remove the label type refactor the imageiter unit test fix the trailing whitespace fix coding style add new line move helper function to the top of the file --- python/mxnet/image/detection.py | 64 ++++++++-- python/mxnet/image/image.py | 5 +- tests/python/unittest/test_image.py | 184 +++++++++++++++------------- 3 files changed, 157 insertions(+), 96 deletions(-) diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py index b27917c86238..d5b5ecab528a 100644 --- a/python/mxnet/image/detection.py +++ b/python/mxnet/image/detection.py @@ -658,19 +658,26 @@ class ImageDetIter(ImageIter): Data name for provided symbols. label_name : str Name for detection labels + last_batch_handle : str, optional + How to handle the last batch. + This parameter can be 'pad'(default), 'discard' or 'roll_over'. + If 'pad', the last batch will be padded with data starting from the begining + If 'discard', the last batch will be discarded + If 'roll_over', the remaining elements will be rolled over to the next iteration kwargs : ... More arguments for creating augmenter. See mx.image.CreateDetAugmenter. """ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, - data_name='data', label_name='label', **kwargs): + data_name='data', label_name='label', last_batch_handle='pad', **kwargs): super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape, path_imgrec=path_imgrec, path_imglist=path_imglist, path_root=path_root, path_imgidx=path_imgidx, shuffle=shuffle, part_index=part_index, num_parts=num_parts, aug_list=[], imglist=imglist, - data_name=data_name, label_name=label_name) + data_name=data_name, label_name=label_name, + last_batch_handle=last_batch_handle) if aug_list is None: self.auglist = CreateDetAugmenter(data_shape, **kwargs) @@ -751,14 +758,10 @@ def reshape(self, data_shape=None, label_shape=None): self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)] self.label_shape = label_shape - def next(self): - """Override the function for returning next batch.""" + def _batchify(self, batch_data, batch_label, start=0): + """Override the helper function for batchifying data""" + i = start batch_size = self.batch_size - c, h, w = self.data_shape - batch_data = nd.zeros((batch_size, c, h, w)) - batch_label = nd.empty(self.provide_label[0][1]) - batch_label[:] = -1 - i = 0 try: while i < batch_size: label, s = self.next_sample() @@ -783,7 +786,48 @@ def next(self): if not i: raise StopIteration - return io.DataBatch([batch_data], [batch_label], batch_size - i) + return i + + def next(self): + """Override the function for returning next batch.""" + batch_size = self.batch_size + c, h, w = self.data_shape + # if last batch data is rolled over + if self._cache_data is not None: + # check both the data and label have values + assert self._cache_label is not None, "_cache_label didn't have values" + assert self._cache_idx is not None, "_cache_idx didn't have values" + batch_data = self._cache_data + batch_label = self._cache_label + i = self._cache_idx + else: + batch_data = nd.zeros((batch_size, c, h, w)) + batch_label = nd.empty(self.provide_label[0][1]) + batch_label[:] = -1 + i = self._batchify(batch_data, batch_label) + # calculate the padding + pad = batch_size - i + # handle padding for the last batch + if pad != 0: + if self.last_batch_handle == 'discard': + raise StopIteration + # if the option is 'roll_over', throw StopIteration and cache the data + elif self.last_batch_handle == 'roll_over' and \ + self._cache_data is None: + self._cache_data = batch_data + self._cache_label = batch_label + self._cache_idx = i + raise StopIteration + else: + _ = self._batchify(batch_data, batch_label, i) + if self.last_batch_handle == 'pad': + self._allow_read = False + else: + self._cache_data = None + self._cache_label = None + self._cache_idx = None + + return io.DataBatch([batch_data], [batch_label], pad=pad) def augmentation_transform(self, data, label): # pylint: disable=arguments-differ """Override Transforms input data with specified augmentations.""" diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index c9a457f5b7e2..9c2a1cbfba2a 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1145,7 +1145,7 @@ def __init__(self, batch_size, data_shape, label_width=1, self.shuffle = shuffle if self.imgrec is None: self.seq = imgkeys - elif shuffle or num_parts > 1: + elif shuffle or num_parts > 1 or path_imgidx: assert self.imgidx is not None self.seq = self.imgidx else: @@ -1261,7 +1261,7 @@ def next(self): i = self._cache_idx # clear the cache data else: - batch_data = nd.empty((batch_size, c, h, w)) + batch_data = nd.zeros((batch_size, c, h, w)) batch_label = nd.empty(self.provide_label[0][1]) i = self._batchify(batch_data, batch_label) # calculate the padding @@ -1285,6 +1285,7 @@ def next(self): self._cache_data = None self._cache_label = None self._cache_idx = None + return io.DataBatch([batch_data], [batch_label], pad=pad) def check_data_shape(self, data_shape): diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 4f66823cdbf1..4063027cc1e5 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -25,6 +25,7 @@ from nose.tools import raises + def _get_data(url, dirname): import os, tarfile download(url, dirname=dirname, overwrite=False) @@ -50,6 +51,62 @@ def _generate_objects(): label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist() return [2, 5] + label +def _test_imageiter_last_batch(imageiter_list, assert_data_shape): + test_iter = imageiter_list[0] + # test batch data shape + for _ in range(3): + for batch in test_iter: + assert batch.data[0].shape == assert_data_shape + test_iter.reset() + # test last batch handle(discard) + test_iter = imageiter_list[1] + i = 0 + for batch in test_iter: + i += 1 + assert i == 5 + # test last_batch_handle(pad) + test_iter = imageiter_list[2] + i = 0 + for batch in test_iter: + if i == 0: + first_three_data = batch.data[0][:2] + if i == 5: + last_three_data = batch.data[0][1:] + i += 1 + assert i == 6 + assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) + # test last_batch_handle(roll_over) + test_iter = imageiter_list[3] + i = 0 + for batch in test_iter: + if i == 0: + first_image = batch.data[0][0] + i += 1 + assert i == 5 + test_iter.reset() + first_batch_roll_over = test_iter.next() + assert np.array_equal( + first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) + assert first_batch_roll_over.pad == 2 + # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over + for _ in test_iter: + pass + test_iter.reset() + first_batch_roll_over_twice = test_iter.next() + assert np.array_equal( + first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) + assert first_batch_roll_over_twice.pad == 1 + # we've called next once + i = 1 + for _ in test_iter: + i += 1 + # test the third epoch with size 6 + assert i == 6 + # test shuffle option for sanity test + test_iter = imageiter_list[4] + for _ in test_iter: + pass + class TestImage(unittest.TestCase): IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz" @@ -151,86 +208,32 @@ def test_color_normalize(self): assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) def test_imageiter(self): - def check_imageiter(dtype='float32'): - im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] - fname = './data/test_imageiter.lst' - file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) - for k, x in enumerate(TestImage.IMAGES)] - with open(fname, 'w') as f: - for line in file_list: - f.write(line + '\n') - - test_list = ['imglist', 'path_imglist'] + im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] + fname = './data/test_imageiter.lst' + file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) + for k, x in enumerate(TestImage.IMAGES)] + with open(fname, 'w') as f: + for line in file_list: + f.write(line + '\n') + test_list = ['imglist', 'path_imglist'] + for dtype in ['int32', 'float32', 'int64', 'float64']: for test in test_list: imglist = im_list if test == 'imglist' else None path_imglist = fname if test == 'path_imglist' else None - - test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype) - # test batch data shape - for _ in range(3): - for batch in test_iter: - assert batch.data[0].shape == (2, 3, 224, 224) - test_iter.reset() - # test last batch handle(discard) - test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard') - i = 0 - for batch in test_iter: - i += 1 - assert i == 5 - # test last_batch_handle(pad) - test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') - i = 0 - for batch in test_iter: - if i == 0: - first_three_data = batch.data[0][:2] - if i == 5: - last_three_data = batch.data[0][1:] - i += 1 - assert i == 6 - assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) - # test last_batch_handle(roll_over) - test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over') - i = 0 - for batch in test_iter: - if i == 0: - first_image = batch.data[0][0] - i += 1 - assert i == 5 - test_iter.reset() - first_batch_roll_over = test_iter.next() - assert np.array_equal( - first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) - assert first_batch_roll_over.pad == 2 - # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over - for _ in test_iter: - pass - test_iter.reset() - first_batch_roll_over_twice = test_iter.next() - assert np.array_equal( - first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) - assert first_batch_roll_over_twice.pad == 1 - # we've called next once - i = 1 - for _ in test_iter: - i += 1 - # test the third epoch with size 6 - assert i == 6 - # test shuffle option for sanity test - test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') - for _ in test_iter: - pass - - for dtype in ['int32', 'float32', 'int64', 'float64']: - check_imageiter(dtype) - - # test with default dtype - check_imageiter() + imageiter_list = [ + mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype), + mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard'), + mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad'), + mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over'), + mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') + ] + _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224)) @with_seed() def test_augmenters(self): @@ -259,16 +262,20 @@ def test_image_detiter(self): im_list = [_generate_objects() + [x] for x in TestImage.IMAGES] det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='') for _ in range(3): - for batch in det_iter: + for _ in det_iter: pass - det_iter.reset() - + det_iter.reset() val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='') det_iter = val_iter.sync_label_shape(det_iter) assert det_iter.data_shape == val_iter.data_shape assert det_iter.label_shape == val_iter.label_shape - # test file list + # test batch_size is not divisible by number of images + det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='') + for _ in det_iter: + pass + + # test file list with last batch handle fname = './data/test_imagedetiter.lst' im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)] with open(fname, 'w') as f: @@ -276,10 +283,19 @@ def test_image_detiter(self): line = '\t'.join([str(k) for k in line]) f.write(line + '\n') - det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname, - path_root='') - for batch in det_iter: - pass + imageiter_list = [ + mx.image.ImageDetIter(2, (3, 400, 400), + path_imglist=fname, path_root=''), + mx.image.ImageDetIter(3, (3, 400, 400), + path_imglist=fname, path_root='', last_batch_handle='discard'), + mx.image.ImageDetIter(3, (3, 400, 400), + path_imglist=fname, path_root='', last_batch_handle='pad'), + mx.image.ImageDetIter(3, (3, 400, 400), + path_imglist=fname, path_root='', last_batch_handle='roll_over'), + mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True, + path_imglist=fname, path_root='', last_batch_handle='pad') + ] + _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400)) def test_det_augmenters(self): # only test if all augmenters will work