Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix the situation where idx didn't align with rec (#13550)
Browse files Browse the repository at this point in the history
minor fix the image.py

add last_batch_handle for ImageDetIter

remove the label type

refactor the imageiter unit test

fix the trailing whitespace

fix coding style

add new line

move helper function to the top of the file
  • Loading branch information
stu1130 authored and zhreshold committed Dec 8, 2018
1 parent 2d08816 commit b7f0122
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 96 deletions.
64 changes: 54 additions & 10 deletions python/mxnet/image/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
Data name for provided symbols.
label_name : str
Name for detection labels
last_batch_handle : str, optional
How to handle the last batch.
This parameter can be 'pad'(default), 'discard' or 'roll_over'.
If 'pad', the last batch will be padded with data starting from the beginning
If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
"""
def __init__(self, batch_size, data_shape,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='label', **kwargs):
data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
path_imgrec=path_imgrec, path_imglist=path_imglist,
path_root=path_root, path_imgidx=path_imgidx,
shuffle=shuffle, part_index=part_index,
num_parts=num_parts, aug_list=[], imglist=imglist,
data_name=data_name, label_name=label_name)
data_name=data_name, label_name=label_name,
last_batch_handle=last_batch_handle)

if aug_list is None:
self.auglist = CreateDetAugmenter(data_shape, **kwargs)
Expand Down Expand Up @@ -751,14 +758,10 @@ def reshape(self, data_shape=None, label_shape=None):
self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
self.label_shape = label_shape

def next(self):
"""Override the function for returning next batch."""
def _batchify(self, batch_data, batch_label, start=0):
"""Override the helper function for batchifying data"""
i = start
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
batch_label[:] = -1
i = 0
try:
while i < batch_size:
label, s = self.next_sample()
Expand All @@ -783,7 +786,48 @@ def next(self):
if not i:
raise StopIteration

return io.DataBatch([batch_data], [batch_label], batch_size - i)
return i

def next(self):
    """Return the next detection batch, honouring ``last_batch_handle``.

    A batch cached by a previous 'roll_over' epoch is resumed when present;
    otherwise a fresh zero-filled data buffer and a label buffer filled with
    -1 are allocated and filled through ``self._batchify``.

    Returns
    -------
    io.DataBatch
        Built from the data and label buffers, with ``pad`` set to
        ``batch_size`` minus the slot count known before any top-up.

    Raises
    ------
    StopIteration
        When the batch is partial and the mode is 'discard', or when the
        mode is 'roll_over' and nothing is cached yet (the partial batch is
        cached first). May also propagate from ``self._batchify``.
    """
    batch_size = self.batch_size
    c, h, w = self.data_shape

    if self._cache_data is None:
        # nothing rolled over: start from empty buffers
        batch_data = nd.zeros((batch_size, c, h, w))
        batch_label = nd.empty(self.provide_label[0][1])
        batch_label[:] = -1
        filled = self._batchify(batch_data, batch_label)
    else:
        # the data cache is only valid together with its label and index
        assert self._cache_label is not None, "_cache_label didn't have values"
        assert self._cache_idx is not None, "_cache_idx didn't have values"
        batch_data = self._cache_data
        batch_label = self._cache_label
        filled = self._cache_idx

    # number of slots still unfilled at this point
    pad = batch_size - filled
    if pad == 0:
        return io.DataBatch([batch_data], [batch_label], pad=pad)

    handle = self.last_batch_handle
    if handle == 'discard':
        raise StopIteration
    if handle == 'roll_over' and self._cache_data is None:
        # keep the partial batch for the next epoch and end this one
        self._cache_data = batch_data
        self._cache_label = batch_label
        self._cache_idx = filled
        raise StopIteration

    # top up the remaining slots, starting at the first unfilled index
    self._batchify(batch_data, batch_label, filled)
    if handle == 'pad':
        self._allow_read = False
    else:
        # the rolled-over batch has been consumed
        self._cache_data = None
        self._cache_label = None
        self._cache_idx = None

    return io.DataBatch([batch_data], [batch_label], pad=pad)

def augmentation_transform(self, data, label): # pylint: disable=arguments-differ
"""Override Transforms input data with specified augmentations."""
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
elif shuffle or num_parts > 1:
elif shuffle or num_parts > 1 or path_imgidx:
assert self.imgidx is not None
self.seq = self.imgidx
else:
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def next(self):
i = self._cache_idx
# clear the cache data
else:
batch_data = nd.empty((batch_size, c, h, w))
batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
Expand All @@ -1285,6 +1285,7 @@ def next(self):
self._cache_data = None
self._cache_label = None
self._cache_idx = None

return io.DataBatch([batch_data], [batch_label], pad=pad)

def check_data_shape(self, data_shape):
Expand Down
184 changes: 100 additions & 84 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from nose.tools import raises


def _get_data(url, dirname):
import os, tarfile
download(url, dirname=dirname, overwrite=False)
Expand All @@ -50,6 +51,62 @@ def _generate_objects():
label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
return [2, 5] + label

def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
    """Check the ``last_batch_handle`` behaviour of five pre-built iterators.

    The expected batch counts (5 and 6) are hard-coded, so the iterators must
    be built over a fixture that matches them.

    Parameters
    ----------
    imageiter_list : list
        Iterators consumed by position:

        * ``[0]`` - iterated for three epochs (with ``reset()`` in between);
          every batch must have shape ``assert_data_shape``.
        * ``[1]`` - 'discard' check: must yield exactly 5 batches.
        * ``[2]`` - 'pad' check: must yield exactly 6 batches, and entries
          ``[1:]`` of the 6th batch must equal entries ``[:2]`` of the 1st.
        * ``[3]`` - 'roll_over' check: 5 batches in the first epoch; after
          ``reset()`` the first batch reports ``pad == 2`` and holds the very
          first image at index 1; after the next ``reset()`` the first batch
          reports ``pad == 1`` and holds that image at index 2, and the
          epoch has 6 batches in total.
        * ``[4]`` - iterated once as a smoke test.
    assert_data_shape : tuple
        Expected ``batch.data[0].shape`` for iterator ``[0]``.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    # [0] batch data shape stays stable over several epochs
    shape_iter = imageiter_list[0]
    for _ in range(3):
        for batch in shape_iter:
            assert batch.data[0].shape == assert_data_shape, \
                "unexpected batch shape %s" % (batch.data[0].shape,)
        shape_iter.reset()

    # [1] last_batch_handle == 'discard'
    num_batches = sum(1 for _ in imageiter_list[1])
    assert num_batches == 5, \
        "'discard': expected 5 batches, got %d" % num_batches

    # [2] last_batch_handle == 'pad'
    head_pair = tail_pair = None
    num_batches = 0
    for batch in imageiter_list[2]:
        if num_batches == 0:
            head_pair = batch.data[0][:2]
        if num_batches == 5:
            tail_pair = batch.data[0][1:]
        num_batches += 1
    assert num_batches == 6, \
        "'pad': expected 6 batches, got %d" % num_batches
    assert np.array_equal(head_pair.asnumpy(), tail_pair.asnumpy()), \
        "'pad': padding does not match the start of the data"

    # [3] last_batch_handle == 'roll_over'
    roll_iter = imageiter_list[3]
    first_image = None
    num_batches = 0
    for batch in roll_iter:
        if num_batches == 0:
            first_image = batch.data[0][0]
        num_batches += 1
    assert num_batches == 5, \
        "'roll_over': expected 5 batches in the first epoch, got %d" % num_batches
    roll_iter.reset()
    first_batch_roll_over = roll_iter.next()
    assert np.array_equal(
        first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()), \
        "'roll_over': first image not found at index 1 after one reset"
    assert first_batch_roll_over.pad == 2, \
        "'roll_over': expected pad 2, got %s" % first_batch_roll_over.pad
    # the iterator must keep working after reset is called several times
    for _ in roll_iter:
        pass
    roll_iter.reset()
    first_batch_roll_over_twice = roll_iter.next()
    assert np.array_equal(
        first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()), \
        "'roll_over': first image not found at index 2 after two resets"
    assert first_batch_roll_over_twice.pad == 1, \
        "'roll_over': expected pad 1, got %s" % first_batch_roll_over_twice.pad
    # next() has already been called once in this epoch
    num_batches = 1 + sum(1 for _ in roll_iter)
    # the third epoch has 6 batches
    assert num_batches == 6, \
        "'roll_over': expected 6 batches in the third epoch, got %d" % num_batches

    # [4] shuffle option, sanity check only
    for _ in imageiter_list[4]:
        pass


class TestImage(unittest.TestCase):
IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
Expand Down Expand Up @@ -151,86 +208,32 @@ def test_color_normalize(self):
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)

def test_imageiter(self):
def check_imageiter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']
for dtype in ['int32', 'float32', 'int64', 'float64']:
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype)
# test batch data shape
for _ in range(3):
for batch in test_iter:
assert batch.data[0].shape == (2, 3, 224, 224)
test_iter.reset()
# test last batch handle(discard)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
i = 0
for batch in test_iter:
i += 1
assert i == 5
# test last_batch_handle(pad)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
i = 0
for batch in test_iter:
if i == 0:
first_three_data = batch.data[0][:2]
if i == 5:
last_three_data = batch.data[0][1:]
i += 1
assert i == 6
assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
# test last_batch_handle(roll_over)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
i = 0
for batch in test_iter:
if i == 0:
first_image = batch.data[0][0]
i += 1
assert i == 5
test_iter.reset()
first_batch_roll_over = test_iter.next()
assert np.array_equal(
first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over.pad == 2
# test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
for _ in test_iter:
pass
test_iter.reset()
first_batch_roll_over_twice = test_iter.next()
assert np.array_equal(
first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over_twice.pad == 1
# we've called next once
i = 1
for _ in test_iter:
i += 1
# test the third epoch with size 6
assert i == 6
# test shuffle option for sanity test
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
for _ in test_iter:
pass

for dtype in ['int32', 'float32', 'int64', 'float64']:
check_imageiter(dtype)

# test with default dtype
check_imageiter()
imageiter_list = [
mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over'),
mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
]
_test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))

@with_seed()
def test_augmenters(self):
Expand Down Expand Up @@ -259,27 +262,40 @@ def test_image_detiter(self):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
for _ in range(3):
for batch in det_iter:
for _ in det_iter:
pass
det_iter.reset()

det_iter.reset()
val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
det_iter = val_iter.sync_label_shape(det_iter)
assert det_iter.data_shape == val_iter.data_shape
assert det_iter.label_shape == val_iter.label_shape

# test file list
# test batch_size is not divisible by number of images
det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, path_root='')
for _ in det_iter:
pass

# test file list with last batch handle
fname = './data/test_imagedetiter.lst'
im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in im_list:
line = '\t'.join([str(k) for k in line])
f.write(line + '\n')

det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
path_root='')
for batch in det_iter:
pass
imageiter_list = [
mx.image.ImageDetIter(2, (3, 400, 400),
path_imglist=fname, path_root=''),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='discard'),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='pad'),
mx.image.ImageDetIter(3, (3, 400, 400),
path_imglist=fname, path_root='', last_batch_handle='roll_over'),
mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
path_imglist=fname, path_root='', last_batch_handle='pad')
]
_test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))

def test_det_augmenters(self):
# only test if all augmenters will work
Expand Down

0 comments on commit b7f0122

Please sign in to comment.