diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index b67ab624ac70..611592aa4d82 100755 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -217,7 +217,7 @@ def _init_bilinear(self, _, arr): c = (2 * f - 1 - f % 2) / (2. * f) for i in range(np.prod(shape)): x = i % shape[3] - y = (i / shape[3]) % shape[2] + y = (i // shape[3]) % shape[2] weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) arr[:] = weight.reshape(shape) @@ -657,7 +657,7 @@ def _init_weight(self, _, arr): c = (2 * f - 1 - f % 2) / (2. * f) for i in range(np.prod(shape)): x = i % shape[3] - y = (i / shape[3]) % shape[2] + y = (i // shape[3]) % shape[2] weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) arr[:] = weight.reshape(shape) diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py index efd6ef36744f..c8bf01f48ca3 100644 --- a/tests/python/unittest/test_init.py +++ b/tests/python/unittest/test_init.py @@ -60,8 +60,17 @@ def check_rsp_const_init(init, val): check_rsp_const_init(mx.initializer.Zero(), 0.) check_rsp_const_init(mx.initializer.One(), 1.) +def test_bilinear_init(): + bili = mx.init.Bilinear() + bili_weight = mx.ndarray.empty((1,1,4,4)) + bili._init_weight(None, bili_weight) + bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]]) + bili_2d = bili_1d * np.transpose(bili_1d) + assert (bili_2d == bili_weight.asnumpy()).all() + if __name__ == '__main__': test_variable_init() test_default_init() test_aux_init() test_rsp_const_init() + test_bilinear_init()