Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
large array support for randint
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Feb 23, 2019
1 parent 2347017 commit 8818dea
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/operator/random/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ struct UniformSampler {
template<typename xpu>
struct SampleRandIntKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
const int N, const int step,
MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lower, const IType *upper, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
Expand Down
11 changes: 11 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_ndarray_random_uniform():
a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
assert a[-1][0] != 0

def test_ndarray_random_randint():
a = nd.random.randint(100, 10000, shape=(LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
# check if randint can generate value greater than 2**32 (large)
low_large_value = 2**32
high_large_value = 2**34
a = nd.random.randint(low_large_value,high_large_value)
low = mx.nd.array([low_large_value],dtype='int64')
high = mx.nd.array([high_large_value],dtype='int64')
assert a.__gt__(low) & a.__lt__(high)

def test_ndarray_empty():
a = nd.empty((LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
Expand Down

0 comments on commit 8818dea

Please sign in to comment.