Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-160] Fix for issue 9062 (#10413)
Browse files Browse the repository at this point in the history
* fix typo in sample_multinomial_op.h (issue #9062)

* add test case for 1D prob input case
  • Loading branch information
haojin2 authored and piiswrong committed Apr 5, 2018
1 parent ed416be commit 4342bd1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/operator/random/sample_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
if (ishape.ndim() == 1) {
if (param.shape.ndim()) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, param.shape);
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1));
}
return true;
}
Expand Down
51 changes: 28 additions & 23 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,29 +276,34 @@ def test_parallel_random_seed_setting():

@with_seed()
def test_sample_multinomial():
x = mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0
dx = mx.nd.ones_like(x)
mx.contrib.autograd.mark_variables([x], [dx])
# Adding rtol and increasing samples needed to pass with seed 2951820647
samples = 5000
with mx.autograd.record():
y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
r = prob * 5
r.backward()

y = y.asnumpy()
x = x.asnumpy()
for i in range(x.shape[0]):

freq = np.bincount(y[i], minlength=5)/np.float32(samples)*x[i].sum()
mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
rprob = x[i][y[i]]/x[i].sum()
mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i])

real_dx = np.zeros((5,))
for j in range(samples):
real_dx[y[i][j]] += 5.0 / rprob[j]
mx.test_utils.assert_almost_equal(real_dx, dx.asnumpy()[i], rtol=1e-4)
for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
dx = mx.nd.ones_like(x)
mx.contrib.autograd.mark_variables([x], [dx])
# Adding rtol and increasing samples needed to pass with seed 2951820647
samples = 5000
with mx.autograd.record():
y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
r = prob * 5
r.backward()

y = y.asnumpy()
x = x.asnumpy()
dx = dx.asnumpy()
if len(x.shape) is 1:
x = x.reshape((1, x.shape[0]))
dx = dx.reshape(1, dx.shape[0])
y = y.reshape((1, y.shape[0]))
prob = prob.reshape((1, prob.shape[0]))
for i in range(x.shape[0]):
freq = np.bincount(y[i,:], minlength=5)/np.float32(samples)*x[i,:].sum()
mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
rprob = x[i][y[i]]/x[i].sum()
mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)

real_dx = np.zeros((5,))
for j in range(samples):
real_dx[y[i][j]] += 5.0 / rprob[j]
mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)

# Test the generators with the chi-square testing
@with_seed()
Expand Down

0 comments on commit 4342bd1

Please sign in to comment.