diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py index efd6ef36744f..c8bf01f48ca3 100644 --- a/tests/python/unittest/test_init.py +++ b/tests/python/unittest/test_init.py @@ -60,8 +60,17 @@ def check_rsp_const_init(init, val): check_rsp_const_init(mx.initializer.Zero(), 0.) check_rsp_const_init(mx.initializer.One(), 1.) +def test_bilinear_init(): + bili = mx.init.Bilinear() + bili_weight = mx.ndarray.empty((1,1,4,4)) + bili._init_weight(None, bili_weight) + bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]]) + bili_2d = bili_1d * np.transpose(bili_1d) + assert (bili_2d == bili_weight.asnumpy()).all() + if __name__ == '__main__': test_variable_init() test_default_init() test_aux_init() test_rsp_const_init() + test_bilinear_init()