Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support long for mx.random.seed #14314

Merged
merged 7 commits into from
Mar 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/mxnet/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

import ctypes
from .base import _LIB, check_call
from .base import _LIB, check_call, integer_types
from .ndarray.random import *
from .context import Context

Expand Down Expand Up @@ -90,9 +90,9 @@ def seed(seed_state, ctx="all"):
[[ 2.5020072 -1.6884501]
[-0.7931333 -1.4218881]]
"""
if not isinstance(seed_state, int):
if not isinstance(seed_state, integer_types):
raise ValueError('seed_state must be int')
seed_state = ctypes.c_int(seed_state)
seed_state = ctypes.c_int(int(seed_state))
if ctx == "all":
check_call(_LIB.MXRandomSeed(seed_state))
else:
Expand Down
29 changes: 28 additions & 1 deletion tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def test_parallel_random_seed_setting():
# Avoid excessive test cpu runtimes
num_temp_seeds = 25 if ctx.device_type == 'gpu' else 1
# To flush out a possible race condition, run multiple times

for _ in range(20):
# Create enough samples such that we get a meaningful distribution.
shape = (200, 200)
Expand Down Expand Up @@ -670,7 +671,7 @@ def gen_data(seed=None):
with random_seed(seed):
python_data = [rnd.random() for _ in range(size)]
np_data = np.random.rand(size)
mx_data = mx.nd.random_uniform(shape=shape, ctx=ctx).asnumpy()
mx_data = mx.random.uniform(shape=shape, ctx=ctx).asnumpy()
return (seed, python_data, np_data, mx_data)

# check data, expecting them to be the same or different based on the seeds
Expand Down Expand Up @@ -712,6 +713,32 @@ def check_data(a, b):
for j in range(i+1, num_seeds):
check_data(data[i],data[j])

@with_seed()
def test_random_seed():
shape = (5, 5)
seed = rnd.randint(-(1 << 31), (1 << 31))

def _assert_same_mx_arrays(a, b):
assert len(a) == len(b)
for a_i, b_i in zip(a, b):
assert (a_i.asnumpy() == b_i.asnumpy()).all()

N = 100
mx.random.seed(seed)
v1 = [mx.random.uniform(shape=shape) for _ in range(N)]

mx.random.seed(seed)
v2 = [mx.random.uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v2)

try:
long
mx.random.seed(long(seed))
v3 = [mx.random.uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v3)
except NameError:
pass

@with_seed()
def test_unique_zipfian_generator():
ctx = mx.context.current_context()
Expand Down