Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Moves seed_aug parameter to ImageRecParserParam and re-seeds RNG befo…
Browse files Browse the repository at this point in the history
…re each augmentation to guarantee reproducibilit
  • Loading branch information
perdasilva authored and Per Goncalves da Silva committed Dec 7, 2018
1 parent 9c0d173 commit fd1e421
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
13 changes: 1 addition & 12 deletions src/io/image_aug_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
int pad;
/*! \brief shape of the image data*/
TShape data_shape;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;

// declare parameters
DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) {
Expand Down Expand Up @@ -188,8 +186,6 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
DMLC_DECLARE_FIELD(pad).set_default(0)
.describe("Change size from ``[width, height]`` into "
"``[pad + width + pad, pad + height + pad]`` by padding pixes");
DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional<int>())
.describe("Random seed for augmentations.");
}
};

Expand All @@ -208,9 +204,7 @@ std::vector<dmlc::ParamFieldInfo> ListDefaultAugParams() {
class DefaultImageAugmenter : public ImageAugmenter {
public:
// contructor
DefaultImageAugmenter() {
seed_init_state = false;
}
DefaultImageAugmenter() {}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
std::vector<std::pair<std::string, std::string> > kwargs_left;
kwargs_left = param_.InitAllowUnknown(kwargs);
Expand Down Expand Up @@ -250,10 +244,6 @@ class DefaultImageAugmenter : public ImageAugmenter {
}
cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
common::RANDOM_ENGINE *prnd) override {
if (!seed_init_state && param_.seed_aug.has_value()) {
prnd->seed(param_.seed_aug.value());
seed_init_state = true;
}
using mshadow::index_t;
bool is_cropped = false;

Expand Down Expand Up @@ -558,7 +548,6 @@ class DefaultImageAugmenter : public ImageAugmenter {
DefaultImageAugmentParam param_;
/*! \brief list of possible rotate angle */
std::vector<int> rotate_list_;
bool seed_init_state;
};

ImageAugmenter* ImageAugmenter::Create(const std::string& name) {
Expand Down
4 changes: 4 additions & 0 deletions src/io/image_iter_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
size_t shuffle_chunk_size;
/*! \brief the seed for chunk shuffling*/
int shuffle_chunk_seed;
/*! \brief random seed for augmentations */
dmlc::optional<int> seed_aug;

// declare parameters
DMLC_DECLARE_PARAMETER(ImageRecParserParam) {
Expand Down Expand Up @@ -165,6 +167,8 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
.describe("The data shuffle buffer size in MB. Only valid if shuffle is true.");
DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0)
.describe("The random seed for shuffling");
DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional<int>())
.describe("Random seed for augmentations.");
}
};

Expand Down
7 changes: 7 additions & 0 deletions src/io/iter_image_recordio_2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,13 @@ inline size_t ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t*
cv::Mat res;
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);

// If augmentation seed is supplied
// Re-seed RNG to guarantee reproducible results
if (param_.seed_aug.has_value()) {
prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic);
}

switch (param_.data_shape[0]) {
case 1:
#if MXNET_USE_LIBJPEG_TURBO
Expand Down
14 changes: 8 additions & 6 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def check_CSVIter_synthetic(dtype='float32'):
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)

@unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359")
# @unittest.skip("Flaky test: /~https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3
Expand All @@ -450,7 +450,8 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -469,7 +470,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

# check whether to get different images after change seed_aug
Expand All @@ -490,7 +491,7 @@ def test_ImageRecordIter_seed_augmentation():
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

# check whether seed_aug changes the iterator behavior
Expand All @@ -502,7 +503,8 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data = batch.data[0].asnumpy().astype(np.uint8)
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand All @@ -512,7 +514,7 @@ def test_ImageRecordIter_seed_augmentation():
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[0].asnumpy().astype(np.uint8)
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

if __name__ == "__main__":
Expand Down

0 comments on commit fd1e421

Please sign in to comment.