Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Enabling large tensor support for binary broadcast operators (#16755)
Browse files Browse the repository at this point in the history
* using dim_t as type for dimensions in BinaryBroadcastShape

* add arctan2 and hypot test

* large vector for both

* Revert "using dim_t as type for dimensions in BinaryBroadcastShape"

This reverts commit 3d12ed2.
  • Loading branch information
ChaiBapchya authored and apeforest committed Jan 9, 2020
1 parent 83578b9 commit 6ba9aad
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,17 @@ def test_gather():
assert np.sum(arr[idx[0]] == 2) == SMALL_Y


def test_binary_broadcast():
def check_correctness(mxnet_op, numpy_op, atol=1e-3):
a = mx.nd.ones((LARGE_X, SMALL_Y)).as_np_ndarray()
b = 2*mx.nd.ones((LARGE_X, SMALL_Y)).as_np_ndarray()
res = mxnet_op(a, b)
np_res = numpy_op(1, 2)
assert np.abs(res[-1][-1] - np_res) < atol
check_correctness(mx.np.arctan2, np.arctan2)
check_correctness(mx.np.hypot, np.hypot)


if __name__ == '__main__':
import nose
nose.runmodule()
11 changes: 11 additions & 0 deletions tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,17 @@ def test_infer_shape():
assert out_shapes == [(LARGE_X,)]


def test_binary_broadcast():
def check_correctness(mxnet_op, numpy_op, atol=1e-3):
a = mx.nd.ones(LARGE_X).as_np_ndarray()
b = 2*mx.nd.ones(LARGE_X).as_np_ndarray()
res = mxnet_op(a, b)
np_res = numpy_op(1, 2)
assert np.abs(res[-1] - np_res) < atol
check_correctness(mx.np.arctan2, np.arctan2)
check_correctness(mx.np.hypot, np.hypot)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 6ba9aad

Please sign in to comment.