diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 2c0841bf1ff8..a0d893d601b3 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1672,10 +1672,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, SMALL_X) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp[0] == 1) == SMALL_Y + assert np.sum(tmp[0].asnumpy() == 1) == SMALL_Y # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx[0]] == 2) == SMALL_Y + assert np.sum(arr[idx[0]].asnumpy() == 2) == SMALL_Y def test_binary_broadcast(): diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 8f01372fcf19..bc87fec33e79 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -1049,10 +1049,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, 10, dtype=np.int64) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp == 1) == 10 + assert np.sum(tmp.asnumpy() == 1) == 10 # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx] == 2) == 10 + assert np.sum(arr[idx].asnumpy() == 2) == 10 def test_infer_shape():