diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index f051472d1be7..ed63d1958b58 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -20,6 +20,7 @@ import mxnet.ndarray as nd from mxnet.test_utils import * from mxnet.base import MXNetError +import itertools import numpy as np import os import gzip @@ -32,6 +33,12 @@ import sys from common import assertRaises import unittest +import sys + +if sys.version_info >= (3,0): + from itertools import zip_longest +else: + from itertools import izip_longest as zip_longest def test_MNISTIter(): @@ -427,13 +434,43 @@ def check_CSVIter_synthetic(dtype='float32'): for dtype in ['int32', 'int64', 'float32']: check_CSVIter_synthetic(dtype=dtype) -# @unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359") def test_ImageRecordIter_seed_augmentation(): get_cifar10() seed_aug = 3 + def assert_dataiter_equals(dataiter1, dataiter2): + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + # ensure iterators are of same length + assert(batch1 and batch2) + + # ensure batches are of same length + assert(len(batch1.data) == len(batch2.data)) + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + assert(np.array_equal(data1, data2)) + + def assert_dataiter_not_equals(dataiter1, dataiter2): + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + + # try to ensure iterators are of same length + assert(batch1 and batch2) + + # ensure batches are of same length + assert(len(batch1.data) == len(batch2.data)) + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + if not np.array_equal(data1, data2): + return + assert False, 'Expected data iterators to be different, but they are the same' + # check whether to get constant images after fixing seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -449,11 +486,8 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - test_index = rnd.randint(0, len(batch.data)) - data = batch.data[test_index].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -469,12 +503,12 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_equals(dataiter1, dataiter2) # check whether to get different images after change seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1.reset() + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -490,32 +524,27 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug+1) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(not np.array_equal(data,data2)) + + assert_dataiter_not_equals(dataiter1, dataiter2) # check whether seed_aug changes the iterator behavior - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - test_index = rnd.randint(0, len(batch.data)) - data = batch.data[test_index].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[test_index].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_equals(dataiter1, dataiter2) if __name__ == "__main__": test_NDArrayIter()