diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 598e615f9b71..52fe69bbd434 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8457,7 +8457,7 @@ def test_index_array_default(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - @mx.use_np_compat + @mx.use_np_shape def test_index_array_default_zero_dim(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) @@ -8468,7 +8468,7 @@ def test_index_array_default_zero_dim(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - @mx.use_np_compat + @mx.use_np_shape def test_index_array_default_zero_size(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) @@ -8492,7 +8492,7 @@ def test_index_array_select_axes(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - @mx.use_np_compat + @mx.use_np_shape def test_index_array_select_axes_zero_size(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) @@ -8502,7 +8502,7 @@ def test_index_array_select_axes_zero_size(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - + test_index_array_default() test_index_array_default_zero_dim() test_index_array_default_zero_size()