diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 0fec47f1361d..42e4e9395a94 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -151,8 +151,8 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - // Normalized Tensor will be a float - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); return out_attrs->at(0) != -1; }