diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f36d46af9f01..abee7ade793e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8068,27 +8068,27 @@ def test_index_array_default(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_default_zero_dim(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones(()) - expected = np.zeros((0,)) + input_array = np.ones(()) + expected = np.zeros((0,)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_default_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones((0, 0, 0)) - expected = np.zeros((0, 0, 0, 3)) + input_array = np.ones((0, 0, 0)) + expected = np.zeros((0, 0, 0, 3)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) def test_index_array_select_axes(): shape = (5, 7, 11, 13, 17, 19) @@ -8103,16 +8103,16 @@ def test_index_array_select_axes(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_select_axes_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) - input_array = np.ones((0, 0, 0, 0)) - expected = np.zeros((0, 0, 2)) + input_array = np.ones((0, 0, 0, 0)) + expected = np.zeros((0, 0, 2)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) test_index_array_default() test_index_array_default_zero_dim()