Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Julia: fix argmax/argmin for NDArray #13871

Merged
merged 1 commit into from
Jan 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions julia/src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims)
@_remap _prod(x::NDArray, ::Colon) prod(x)
@_remap _prod(x::NDArray, dims) prod(x; axis = 0 .- dims, keepdims = true)

# TODO: support CartesianIndex ?
"""
argmax(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmax`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmin`](@ref mx.argmin).
"""
Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1
@_remap _argmax(x::NDArray, ::Colon) argmax(x)
@_remap _argmax(x::NDArray, dims) argmax(x; axis = 0 .- dims, keepdims = true)

"""
argmin(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmin`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmax`](@ref mx.argmax).
"""
Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1
@_remap _argmin(x::NDArray, ::Colon) argmin(x)
@_remap _argmin(x::NDArray, dims) argmin(x; axis = 0 .- dims, keepdims = true)

_nddoc[:clip] = _nddoc[:clip!] =
"""
clip(x::NDArray, min, max)
Expand Down Expand Up @@ -1734,6 +1795,10 @@ const _op_import_bl = [ # import black list; do not import these funcs
"broadcast_axis",
"broadcast_axes",
"broadcast_hypot",

# reduction
"argmax",
"argmin",
]

macro _import_ndarray_functions()
Expand Down
46 changes: 46 additions & 0 deletions julia/test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,50 @@ function test_hypot()
@test copy(z) == C
end # function test_hypot

function test_argmax()
@info "NDArray::argmax"
let
A = [1. 5 3;
4 2 6]
x = NDArray(A)

@test copy(argmax(x, dims = 1)) == [2 1 2]
@test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
end

@info "NDArray::argmax::NaN"
let
A = [1. 5 3;
NaN 2 6]
x = NDArray(A)

@test copy(argmax(x, dims = 1)) == [1 1 2]
@test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
end
end

function test_argmin()
@info "NDArray::argmin"
let
A = [1. 5 3;
4 2 6]
x = NDArray(A)

@test copy(argmin(x, dims = 1)) == [1 2 1]
@test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
end

@info "NDArray::argmin::NaN"
let
A = [1. 5 3;
NaN 2 6]
x = NDArray(A)

@test copy(argmin(x, dims = 1)) == [1 2 1]
@test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
end
end

################################################################################
# Run tests
################################################################################
Expand Down Expand Up @@ -1479,6 +1523,8 @@ end # function test_hypot
test_broadcast_to()
test_broadcast_axis()
test_hypot()
test_argmax()
test_argmin()
end

end