diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index 48b9897aa4c6..898ca050891d 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -70,10 +70,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   if (ishape.ndim() == 1) {
     if (param.shape.ndim()) {
       SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
-      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
+      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, param.shape);
     } else {
       SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
-      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
+      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1));
     }
     return true;
   }
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 5138728bfe6e..03c9d5c10023 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -276,29 +276,34 @@ def test_parallel_random_seed_setting():
 
 @with_seed()
 def test_sample_multinomial():
-    x = mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0
-    dx = mx.nd.ones_like(x)
-    mx.contrib.autograd.mark_variables([x], [dx])
-    # Adding rtol and increasing samples needed to pass with seed 2951820647
-    samples = 5000
-    with mx.autograd.record():
-        y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
-        r = prob * 5
-        r.backward()
-
-    y = y.asnumpy()
-    x = x.asnumpy()
-    for i in range(x.shape[0]):
-
-        freq = np.bincount(y[i], minlength=5)/np.float32(samples)*x[i].sum()
-        mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
-        rprob = x[i][y[i]]/x[i].sum()
-        mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i])
-
-        real_dx = np.zeros((5,))
-        for j in range(samples):
-            real_dx[y[i][j]] += 5.0 / rprob[j]
-        mx.test_utils.assert_almost_equal(real_dx, dx.asnumpy()[i], rtol=1e-4)
+    for x in [mx.nd.array([[0,1,2,3,4],[4,3,2,1,0]])/10.0, mx.nd.array([0,1,2,3,4])/10.0]:
+        dx = mx.nd.ones_like(x)
+        mx.contrib.autograd.mark_variables([x], [dx])
+        # Adding rtol and increasing samples needed to pass with seed 2951820647
+        samples = 5000
+        with mx.autograd.record():
+            y, prob = mx.nd.random.multinomial(x, shape=samples, get_prob=True)
+            r = prob * 5
+            r.backward()
+
+        y = y.asnumpy()
+        x = x.asnumpy()
+        dx = dx.asnumpy()
+        if len(x.shape) == 1:
+            x = x.reshape((1, x.shape[0]))
+            dx = dx.reshape((1, dx.shape[0]))
+            y = y.reshape((1, y.shape[0]))
+            prob = prob.reshape((1, prob.shape[0]))
+        for i in range(x.shape[0]):
+            freq = np.bincount(y[i,:], minlength=5)/np.float32(samples)*x[i,:].sum()
+            mx.test_utils.assert_almost_equal(freq, x[i], rtol=0.20)
+            rprob = x[i][y[i]]/x[i].sum()
+            mx.test_utils.assert_almost_equal(np.log(rprob), prob.asnumpy()[i], atol=1e-5)
+
+            real_dx = np.zeros((5,))
+            for j in range(samples):
+                real_dx[y[i][j]] += 5.0 / rprob[j]
+            mx.test_utils.assert_almost_equal(real_dx, dx[i, :], rtol=1e-4, atol=1e-5)
 
 # Test the generators with the chi-square testing
 @with_seed()
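
A minimal usage sketch of the case this patch targets (illustrative only, not part of the diff): with a 1-D probability vector and get_prob=True, mx.nd.random.multinomial returns the samples together with their log-probabilities, and with the shape-inference fix the second output is registered against output slot 1, so both outputs carry the sample shape. The variable names and sample count below are placeholders chosen for the example.

# Sketch (not part of the patch): draw from a 1-D distribution with get_prob=True,
# the case whose second-output shape the SampleMultinomialOpShape fix addresses.
import mxnet as mx
import numpy as np

probs = mx.nd.array([0, 1, 2, 3, 4]) / 10.0   # 1-D distribution over 5 classes
samples, log_prob = mx.nd.random.multinomial(probs, shape=1000, get_prob=True)

# Both outputs share the sample shape, e.g. (1000,) here.
print(samples.shape, log_prob.shape)

# Empirical class frequencies should approach the input probabilities,
# mirroring the bincount check in the updated test.
freq = np.bincount(samples.asnumpy().astype('int64'), minlength=5) / 1000.0
print(freq)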