From ccab255620c7c01b863db8dc76b9549875ccc38f Mon Sep 17 00:00:00 2001 From: perdasilva Date: Fri, 18 Jan 2019 10:45:28 -0800 Subject: [PATCH] test_ImageRecordIter_seed_augmentation flaky test fix (#12485) * Moves seed_aug parameter to ImageRecParserParam and re-seeds RNG before each augmentation to guarantee reproducibilit * Update image record iterator tests to check the whole iterator not only first image --- src/io/image_aug_default.cc | 13 +----- src/io/image_iter_common.h | 4 ++ src/io/iter_image_recordio_2.cc | 7 +++ tests/python/unittest/test_io.py | 79 ++++++++++++++++++++++++-------- 4 files changed, 72 insertions(+), 31 deletions(-) diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc index f31664709bd5..cd06de2b2ad1 100644 --- a/src/io/image_aug_default.cc +++ b/src/io/image_aug_default.cc @@ -97,8 +97,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter seed_aug; // declare parameters DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) { @@ -188,8 +186,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter()) - .describe("Random seed for augmentations."); } }; @@ -208,9 +204,7 @@ std::vector ListDefaultAugParams() { class DefaultImageAugmenter : public ImageAugmenter { public: // contructor - DefaultImageAugmenter() { - seed_init_state = false; - } + DefaultImageAugmenter() {} void Init(const std::vector >& kwargs) override { std::vector > kwargs_left; kwargs_left = param_.InitAllowUnknown(kwargs); @@ -250,10 +244,6 @@ class DefaultImageAugmenter : public ImageAugmenter { } cv::Mat Process(const cv::Mat &src, std::vector *label, common::RANDOM_ENGINE *prnd) override { - if (!seed_init_state && param_.seed_aug.has_value()) { - prnd->seed(param_.seed_aug.value()); - seed_init_state = true; - } using mshadow::index_t; bool is_cropped = false; @@ -558,7 +548,6 @@ class DefaultImageAugmenter : public ImageAugmenter { DefaultImageAugmentParam param_; /*! \brief list of possible rotate angle */ std::vector rotate_list_; - bool seed_init_state; }; ImageAugmenter* ImageAugmenter::Create(const std::string& name) { diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h index a2324a4b5c5b..c9e3933ade28 100644 --- a/src/io/image_iter_common.h +++ b/src/io/image_iter_common.h @@ -131,6 +131,8 @@ struct ImageRecParserParam : public dmlc::Parameter { size_t shuffle_chunk_size; /*! \brief the seed for chunk shuffling*/ int shuffle_chunk_seed; + /*! \brief random seed for augmentations */ + dmlc::optional seed_aug; // declare parameters DMLC_DECLARE_PARAMETER(ImageRecParserParam) { @@ -165,6 +167,8 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("The data shuffle buffer size in MB. Only valid if shuffle is true."); DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0) .describe("The random seed for shuffling"); + DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional()) + .describe("Random seed for augmentations."); } }; diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index b567c729736c..89f7753983db 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -519,6 +519,13 @@ inline size_t ImageRecordIOParser2::ParseChunk(DType* data_dptr, real_t* cv::Mat res; rec.Load(blob.dptr, blob.size); cv::Mat buf(1, rec.content_size, CV_8U, rec.content); + + // If augmentation seed is supplied + // Re-seed RNG to guarantee reproducible results + if (param_.seed_aug.has_value()) { + prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic); + } + switch (param_.data_shape[0]) { case 1: #if MXNET_USE_LIBJPEG_TURBO diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 0641f235aa71..351231bb47fd 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -32,6 +32,10 @@ import sys from common import assertRaises import unittest +try: + from itertools import izip_longest as zip_longest +except: + from itertools import zip_longest def test_MNISTIter(): @@ -427,13 +431,56 @@ def check_CSVIter_synthetic(dtype='float32'): for dtype in ['int32', 'int64', 'float32']: check_CSVIter_synthetic(dtype=dtype) -@unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359") def test_ImageRecordIter_seed_augmentation(): get_cifar10() seed_aug = 3 + def assert_dataiter_items_equals(dataiter1, dataiter2): + """ + Asserts that two data iterators have the same numbner of batches, + that the batches have the same number of items, and that the items + are the equal. + """ + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + + # ensure iterators contain the same number of batches + # zip_longest will return None if on of the iterators have run out of batches + assert batch1 and batch2, 'The iterators do not contain the same number of batches' + + # ensure batches are of same length + assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + assert(np.array_equal(data1, data2)) + + def assert_dataiter_items_not_equals(dataiter1, dataiter2): + """ + Asserts that two data iterators have the same numbner of batches, + that the batches have the same number of items, and that the items + are the _not_ equal. + """ + for batch1, batch2 in zip_longest(dataiter1, dataiter2): + + # ensure iterators are of same length + # zip_longest will return None if on of the iterators have run out of batches + assert batch1 and batch2, 'The iterators do not contain the same number of batches' + + # ensure batches are of same length + assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length' + + # ensure batch data is the same + for i in range(0, len(batch1.data)): + data1 = batch1.data[i].asnumpy().astype(np.uint8) + data2 = batch2.data[i].asnumpy().astype(np.uint8) + if not np.array_equal(data1, data2): + return + assert False, 'Expected data iterators to be different, but they are the same' + # check whether to get constant images after fixing seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -449,10 +496,8 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -468,12 +513,12 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_items_equals(dataiter1, dataiter2) # check whether to get different images after change seed_aug - dataiter = mx.io.ImageRecordIter( + dataiter1.reset() + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, @@ -489,31 +534,27 @@ def test_ImageRecordIter_seed_augmentation(): random_h=10, max_shear_ratio=2, seed_aug=seed_aug+1) - batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) - assert(not np.array_equal(data,data2)) + + assert_dataiter_items_not_equals(dataiter1, dataiter2) # check whether seed_aug changes the iterator behavior - dataiter = mx.io.ImageRecordIter( + dataiter1 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - data = batch.data[0].asnumpy().astype(np.uint8) - dataiter = mx.io.ImageRecordIter( + dataiter2 = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar10_mean.bin", shuffle=False, data_shape=(3, 28, 28), batch_size=3, seed_aug=seed_aug) - batch = dataiter.next() - data2 = batch.data[0].asnumpy().astype(np.uint8) - assert(np.array_equal(data,data2)) + + assert_dataiter_items_equals(dataiter1, dataiter2) if __name__ == "__main__": test_NDArrayIter()