diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index f051472d1be7..351231bb47fd 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -32,6 +32,10 @@ import sys from common import assertRaises import unittest +try: + from itertools import izip_longest as zip_longest +except: + from itertools import zip_longest def test_MNISTIter(): @@ -427,13 +431,56 @@ def check_CSVIter_synthetic(dtype='float32'): for dtype in ['int32', 'int64', 'float32']: check_CSVIter_synthetic(dtype=dtype) -# @unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359") def test_ImageRecordIter_seed_augmentation(): get_cifar10() seed_aug = 3 + def assert_dataiter_items_equals(dataiter1, dataiter2): + """ + Asserts that two data iterators have the same numbner of batches, + that the batches have the same number of items, and that the items + are the equal. + """ + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + + # ensure iterators contain the same number of batches + # zip_longest will return None if on of the iterators have run out of batches + assert batch1 and batch2, 'The iterators do not contain the same number of batches' + + # ensure batches are of same length + assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + assert(np.array_equal(data1, data2)) + + def assert_dataiter_items_not_equals(dataiter1, dataiter2): + """ + Asserts that two data iterators have the same numbner of batches, + that the batches have the same number of items, and that the items + are the _not_ equal. + """ + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + + # ensure iterators are of same length + # zip_longest will return None if on of the iterators have run out of batches + assert batch1 and batch2, 'The iterators do not contain the same number of batches' + + # ensure batches are of same length + assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + if not np.array_equal(data1, data2): + return + assert False, 'Expected data iterators to be different, but they are the same' + # check whether to get constant images after fixing seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -449,11 +496,8 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - test_index = rnd.randint(0, len(batch.data)) - data = batch.data[test_index].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -469,12 +513,12 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_items_equals(dataiter1, dataiter2) # check whether to get different images after change seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1.reset() + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -490,32 +534,27 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug+1) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(not np.array_equal(data,data2)) + + assert_dataiter_items_not_equals(dataiter1, dataiter2) # check whether seed_aug changes the iterator behavior - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - test_index = rnd.randint(0, len(batch.data)) - data = batch.data[test_index].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_items_equals(dataiter1, dataiter2) if __name__ == "__main__": test_NDArrayIter()