Skip to content

Commit

Permalink
add default behaviour for argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
rondogency committed Dec 24, 2018
1 parent 8413180 commit 64e9f1a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
dmlc::optional<int> axis;
bool keepdims;
DMLC_DECLARE_PARAMETER(ReduceAxisParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>())
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
.describe("The axis along which to perform the reduction. "
"Negative values means indexing from right to left. "
"``Requires axis to be set as int, because global reduction "
"is not supported yet.``");
"``The axis need to be set as an int. If the axis is "
"not set, the rightmost axis will be reduced.``");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
Expand Down
26 changes: 25 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,6 @@ def test_dot():
C = mx.nd.dot(A, B, transpose_a=True, transpose_b=True)
assert_almost_equal(c, C.asnumpy(), atol=atol)


@with_seed()
def test_reduce():
sample_num = 200
Expand Down Expand Up @@ -524,6 +523,31 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
keepdims:np_reduce(np.float32(data), axis, keepdims, np.argmin),
mx.nd.argmin, False)

@with_seed()
def test_argmax_argmin():
# test optional parameters
# test name : input data, argmax result, argmin result
tests = {
'axis_0' : [[[1, 2, 3], [4, 5, 6]], [1, 1, 1], [0, 0, 0]],
'keep_dims' : [[[1, 2, 3], [4, 5, 6]], [[2], [2]], [[0], [0]]],
'axis_none' : [[1, 2, 3, 4], 3, 0]
}

arg_max = mx.nd.array(tests['axis_0'][0]).argmax(axis=0)
arg_min = mx.nd.array(tests['axis_0'][0]).argmin(axis=0)
assert_almost_equal(arg_max.asnumpy(), tests['axis_0'][1])
assert_almost_equal(arg_min.asnumpy(), tests['axis_0'][2])

arg_max = mx.nd.array(tests['keep_dims'][0]).argmax(axis=1, keepdims=True)
arg_min = mx.nd.array(tests['keep_dims'][0]).argmin(axis=1, keepdims=True)
assert_almost_equal(arg_max.asnumpy(), tests['keep_dims'][1])
assert_almost_equal(arg_min.asnumpy(), tests['keep_dims'][2])

arg_max = mx.nd.array(tests['axis_none'][0]).argmax()
arg_min = mx.nd.array(tests['axis_none'][0]).argmin()
assert_almost_equal(arg_max.asnumpy(), tests['axis_none'][1])
assert_almost_equal(arg_min.asnumpy(), tests['axis_none'][2])

@with_seed()
def test_broadcast():
sample_num = 1000
Expand Down

0 comments on commit 64e9f1a

Please sign in to comment.