Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update image record iterator tests to check the whole iterator not on…
Browse files Browse the repository at this point in the history
…ly first image
  • Loading branch information
perdasilva committed Nov 23, 2018
1 parent e2fe3fb commit c92ccba
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import mxnet.ndarray as nd
from mxnet.test_utils import *
from mxnet.base import MXNetError
import itertools
import numpy as np
import os
import gzip
Expand All @@ -32,6 +33,12 @@
import sys
from common import assertRaises
import unittest
import sys

if sys.version_info >= (3,0):
from itertools import zip_longest
else:
from itertools import izip_longest as zip_longest


def test_MNISTIter():
Expand Down Expand Up @@ -427,13 +434,43 @@ def check_CSVIter_synthetic(dtype='float32'):
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)

# @unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3

def assert_dataiter_equals(dataiter1, dataiter2):
    """Assert that two data iterators yield identical batches.

    Both iterators are consumed in lockstep. The check fails if one
    iterator runs out of batches before the other, if a pair of batches
    holds a different number of data arrays, or if any pair of arrays
    differs after being cast to ``uint8``.

    Parameters
    ----------
    dataiter1, dataiter2 : iterable
        Iterables of batches. Each batch exposes ``data``, a sequence of
        arrays that provide ``asnumpy()``.

    Raises
    ------
    AssertionError
        If the iterators differ in length, batch size or content.
    """
    for batch1, batch2 in zip_longest(dataiter1, dataiter2):
        # zip_longest pads the shorter iterator with None, so a None on
        # either side means the iterators are of different length
        assert batch1 is not None and batch2 is not None, \
            'Data iterators do not contain the same number of batches'

        # ensure batches are of same length
        assert len(batch1.data) == len(batch2.data), \
            'Batches do not contain the same number of data arrays'

        # ensure batch data is the same
        for item1, item2 in zip(batch1.data, batch2.data):
            data1 = item1.asnumpy().astype(np.uint8)
            data2 = item2.asnumpy().astype(np.uint8)
            assert np.array_equal(data1, data2), \
                'Expected data iterators to be the same, but they are different'

def assert_dataiter_not_equals(dataiter1, dataiter2):
    """Assert that two data iterators differ in at least one data array.

    Both iterators are consumed in lockstep and arrays are compared after
    being cast to ``uint8``. The function returns as soon as the first
    differing pair of arrays is found, so batches after that point are not
    examined (and their lengths are not checked).

    Parameters
    ----------
    dataiter1, dataiter2 : iterable
        Iterables of batches. Each batch exposes ``data``, a sequence of
        arrays that provide ``asnumpy()``.

    Raises
    ------
    AssertionError
        If, before any difference is found, the iterators turn out to have
        a different number of batches or a pair of batches holds a
        different number of arrays; or if every array is identical.
    """
    for batch1, batch2 in zip_longest(dataiter1, dataiter2):
        # zip_longest pads the shorter iterator with None, so a None on
        # either side means the iterators are of different length
        assert batch1 is not None and batch2 is not None, \
            'Data iterators do not contain the same number of batches'

        # ensure batches are of same length
        assert len(batch1.data) == len(batch2.data), \
            'Batches do not contain the same number of data arrays'

        # look for the first pair of arrays that differ
        for item1, item2 in zip(batch1.data, batch2.data):
            data1 = item1.asnumpy().astype(np.uint8)
            data2 = item2.asnumpy().astype(np.uint8)
            if not np.array_equal(data1, data2):
                return

    # raised explicitly (rather than ``assert False``) so the check is not
    # stripped when Python runs with -O
    raise AssertionError(
        'Expected data iterators to be different, but they are the same')

# check that the images are identical when seed_aug is fixed
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -449,11 +486,8 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -469,12 +503,12 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_equals(dataiter1, dataiter2)

# check that the images differ after changing seed_aug
dataiter = mx.io.ImageRecordIter(
dataiter1.reset()
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -490,32 +524,27 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

assert_dataiter_not_equals(dataiter1, dataiter2)

# check whether seed_aug changes the iterator behavior
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_equals(dataiter1, dataiter2)

if __name__ == "__main__":
test_NDArrayIter()
Expand Down

0 comments on commit c92ccba

Please sign in to comment.