diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h index ff3f794d167d..5a653d8a175c 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -50,11 +50,17 @@ namespace op { struct BilinearSampleParam : public dmlc::Parameter { int height; int width; + dmlc::optional scale_height; + dmlc::optional scale_width; DMLC_DECLARE_PARAMETER(BilinearSampleParam) { - DMLC_DECLARE_FIELD(height).set_range(1, 10000) - .describe("output height (required)"); - DMLC_DECLARE_FIELD(width).set_range(1, 10000) - .describe("output width (required)"); + DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000) + .describe("output height (required, but ignored if scale_height is defined)"); + DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000) + .describe("output width (required, but ignored if scale_width is defined)"); + DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional()) + .describe("sampling scale of the height (optional, ignores height if defined)"); + DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional()) + .describe("sampling scale of the scale_width (optional, ignores width if defined)"); } }; @@ -129,8 +135,18 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, const BilinearSampleParam& param = nnvm::get(attrs.parsed); TShape dshape(in_shape->at(0)); if (dshape.ndim() == 0) return false; - dshape[2] = param.height; - dshape[3] = param.width; + if (param.scale_height.has_value()) { + dshape[2] = static_cast(param.scale_height.value() * in_shape->at(0)[2]); + } else { + dshape[2] = param.height; + } + + if (param.scale_height.has_value()) { + dshape[3] = static_cast(param.scale_width.value() * in_shape->at(0)[3]); + } else { + dshape[3] = param.width; + } + out_shape->clear(); out_shape->push_back(dshape); return true; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 67aeddf19c44..3f34ade448dc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6533,6 +6533,11 @@ def check_bilinear_resize_op(shape, height, width): x = mx.nd.random.uniform(shape=shape) y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width) assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) + + x_scale = width / shape[-1] + y_scale = height / shape[-2] + y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale) + assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) shape = (2, 2, 10, 10) check_bilinear_resize_op(shape, 5, 5) check_bilinear_resize_op(shape, 10, 10)