Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Julia: rename mx.clip to clamp for NDArray (#14027)
Browse files Browse the repository at this point in the history
- in order to match Julia `Base.clamp` interface

- depwarn for `mx.clip` included
  • Loading branch information
iblislin authored and wkcn committed Mar 11, 2019
1 parent 47d4d66 commit af41af5
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 35 deletions.
7 changes: 3 additions & 4 deletions julia/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

* Following material from `mx` module got exported (#TBD):
* `NDArray`
* `clip()`
* `clip!()`
* `context()`
* `expand_dims()`
* `@inplace`
Expand Down Expand Up @@ -373,11 +371,12 @@
99.9889 100.533 100.072
```

* Signature of `clip` changed, it doesn't require any keyword argument now.
* Signature of `clip` changed and renamed to `clamp`.
It doesn't require any keyword argument now.
(#TBD)

Before: `clip(x, a_min = -4, a_max = 4)`
After: `clip(x, -4, 4)`
After: `clamp(x, -4, 4)`

### Optimizer

Expand Down
2 changes: 0 additions & 2 deletions julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ export SymbolicNode,

# ndarray.jl
export NDArray,
clip,
clip!,
context,
expand_dims,
@inplace,
Expand Down
8 changes: 6 additions & 2 deletions julia/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ end
@deprecate softmax(x::NDArray; axis = ndims(x)) softmax.(x, axis)
@deprecate log_softmax(x::NDArray; axis = ndims(x)) log_softmax.(x, axis)

@deprecate clip(x; a_min = 0, a_max = 0) clip(x, a_min, a_max)

function broadcast_plus(x::NDArray, y::NDArray)
@warn("broadcast_plus(x, y) is deprecated, use x .+ y instead.")
x .+ y
Expand Down Expand Up @@ -194,3 +192,9 @@ function empty(dims::Int...)
"use `NDArray(undef, dims...)` instead.")
NDArray(undef, dims...)
end

# replaced by Base.clamp
@deprecate clip(x::NDArray, lo::Real, hi::Real) clamp(x, lo, hi)
@deprecate clip!(x::NDArray, lo::Real, hi::Real) clamp!(x, lo, hi)
@deprecate clip(x; a_min = 0, a_max = 0) clamp(x, a_min, a_max)

52 changes: 32 additions & 20 deletions julia/src/ndarray/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,40 +218,52 @@ broadcasted(::typeof(^), x::NDArray{T,N}, y::NDArray{T,N}) where {T,N} =
broadcasted(::typeof(^), x::NDArray{T,N}, y::NDArray{T,M}) where {T,N,M} =
_broadcast_power(x, y)

_nddoc[:clip] = _nddoc[:clip!] =
"""
clip(x::NDArray, min, max)
clip!(x::NDArray, min, max)
clamp(x::NDArray, lo, hi)
Clips (limits) the values in `NDArray`.
Clamps (limits) the values in `NDArray`.
Given an interval, values outside the interval are clipped to the interval edges.
Clipping `x` between `min` and `x` would be:
Clamping `x` between low `lo` and high `hi` would be:
```julia
clip(x, min_, max_) = max(min(x, max_), min_))
clamp(x, lo, hi) = max(min(x, lo), hi))
```
The storage type of clip output depends on storage types of inputs and the
`lo`, `hi` parameter values:
- clamp(default) -> default
- clamp(row_sparse, lo <= 0, hi >= 0) -> row_sparse
- clamp(csr, lo <= 0, hi >= 0) -> csr
- clamp(row_sparse, lo < 0, hi < 0) -> default
- clamp(row_sparse, lo > 0, hi > 0) -> default
- clamp(csr, lo < 0, hi < 0) -> csr
- clamp(csr, lo > 0, hi > 0) -> csr
## Examples
```jldoctest
julia> x = NDArray(1:9);
julia> mx.clip(x, 2, 8)'
julia> clamp(x, 2, 8)'
1×9 mx.NDArray{Int64,2} @ CPU0:
2 2 3 4 5 6 7 8 8
```
The storage type of clip output depends on storage types of inputs and the
`min`, `max` parameter values:
- clip(default) = default
- clip(row_sparse, min <= 0, max >= 0) = row_sparse
- clip(csr, min <= 0, max >= 0) = csr
- clip(row_sparse, min < 0, max < 0) = default
- clip(row_sparse, min > 0, max > 0) = default
- clip(csr, min < 0, max < 0) = csr
- clip(csr, min > 0, max > 0) = csr
julia> clamp(x, 8, 2)'
1×9 NDArray{Int64,2} @ CPU0:
8 8 2 2 2 2 2 2 2
```
"""
Base.clamp(x::NDArray, lo::Real, hi::Real) = _clamp(x, lo, hi)
@_remap _clamp(x::NDArray, lo::Real, hi::Real) clip(x; a_min = lo, a_max = hi)

"""
clamp!(x::NDArray, lo, hi)
See also [`clamp`](@ref).
"""
@_remap clip(x::NDArray, min::Real, max::Real) clip(x; a_min = min, a_max = max)
@_remap clip!(x::NDArray, min::Real, max::Real) clip(x; a_min = min, a_max = max)
Base.clamp!(x::NDArray, lo::Real, hi::Real) = _clamp!(x, lo, hi)
@_remap _clamp!(x::NDArray, lo::Real, hi::Real) clip(x; a_min = lo, a_max = hi)

################################################################################
# remapping to solving type unstablility
Expand Down
12 changes: 12 additions & 0 deletions julia/src/ndarray/remap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ function _docsig(fname::Symbol, sig::Expr, opname::String)
end
end

"""
@_remap(sig::Expr, imp::Expr)
Creating a function in signature `sig` with the function implementation `imp`.
## Arguments
- `sig` is the function signature.
If the function name ends with `!`, it will invoke the corresponding inplace
call.
- `imp` is the underlying libmxnet API call
"""
macro _remap(sig::Expr, imp::Expr)
d = splitdef(:($sig = $imp))
@capture d[:name] (M_.fname_|fname_)
Expand Down
2 changes: 1 addition & 1 deletion julia/src/optimizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function normgrad!(opt::AbstractOptimizer, W::NDArray, ∇::NDArray)
!iszero(s) && @inplace.*= s
# gradient clipping
c = opt.clip
c > 0 && clip!(∇, -c, c)
c > 0 && clamp!(∇, -c, c)
# weight decay
λ = opt.λ
λ > 0 && @inplace+= λ .* W
Expand Down
12 changes: 6 additions & 6 deletions julia/test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,24 +885,24 @@ function test_saveload()
rm(fname)
end

function test_clip()
function test_clamp()
dims = rand_dims()
@info("NDArray::clip::dims = $dims")
@info("NDArray::clamp::dims = $dims")

j_array, nd_array = rand_tensors(dims)
clip_up = maximum(abs.(j_array)) / 2
clip_down = 0
clipped = clip(nd_array, clip_down, clip_up)
clipped = clamp(nd_array, clip_down, clip_up)

# make sure the original array is not modified
@test copy(nd_array) j_array

@test all(clip_down .<= copy(clipped) .<= clip_up)

@info("NDArray::clip!")
@info("NDArray::clamp!")
let
x = NDArray(1.0:20)
clip!(x, 5, 15)
clamp!(x, 5, 15)
@test all(5 .<= copy(x) .<= 15)
end
end
Expand Down Expand Up @@ -1571,7 +1571,7 @@ end
test_mod()
test_gd()
test_saveload()
test_clip()
test_clamp()
test_power()
test_sqrt()
test_eltype()
Expand Down

0 comments on commit af41af5

Please sign in to comment.