Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[API] Add bitwise_left/right_shift (#20587)
Browse files Browse the repository at this point in the history
* add bitwise_left/right_shift

* add more methods

* add mshadow_op.h

* fix

* fix lint & add tests

* fix

* update operator_tune.cc

* update amp list

* add rtc functions

* fix bitwise rtc functions & numpy op gpu test overriding issue

* clang-format

* fix ci

* add int16 support

* add MXNET_INT_TYPE_SWITCH_EXT

* fix sanity check

* fix lint

* fix

* fix lint
  • Loading branch information
barry-jin authored Nov 1, 2021
1 parent 3dffdc1 commit 943ab64
Show file tree
Hide file tree
Showing 16 changed files with 749 additions and 5 deletions.
6 changes: 6 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@
'_npi_bitwise_or_scalar',
'_npi_bitwise_xor',
'_npi_bitwise_xor_scalar',
'_npi_bitwise_left_shift',
'_npi_bitwise_left_shift_scalar',
'_npi_bitwise_right_shift',
'_npi_bitwise_right_shift_scalar',
'_npi_rbitwise_left_shift_scalar',
'_npi_rbitwise_right_shift_scalar',
'_npi_blackman',
'_npi_boolean_mask_assign_scalar',
'_npi_boolean_mask_assign_tensor',
Expand Down
80 changes: 79 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'where', 'bincount', 'rollaxis', 'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'diag', 'diagonal',
'positive', 'logaddexp', 'floor_divide']
'positive', 'logaddexp', 'floor_divide', 'bitwise_left_shift', 'bitwise_right_shift']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -10015,3 +10015,81 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
raise ValueError("only where=None or where=True cases are supported for now")
return _api_internal.sum(a, axis, dtype, keepdims, initial, out)
# pylint:enable=redefined-outer-name, too-many-arguments


@set_module('mxnet.ndarray.numpy')
def bitwise_left_shift(x1, x2, out=None):
r"""
Shift the bits of and integer to the left. Bits are shifted to the left by
appending x2 0s at the right of x1. Since the internal representation of numbers
is in binary format, this operation is equivalent to ``x1 * 2**x2``
Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.binary_repr(5)
'101'
>>> np.left_shift(5, 2)
20
>>> np.binary_repr(20)
'10100'
>>> np.left_shift(5, np.array([1,2,3]))
array([10, 20, 40])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.left_shift(x1, x2, out=out)
return _api_internal.bitwise_left_shift(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
def bitwise_right_shift(x1, x2, out=None):
r"""
Shift the bits of and integer to the right. Bits are shifted to the right by
x2. Because the internal representation of numbers is in binary format,
this operation is equivalent to ``x1 / 2**x2``
Parameters
----------
x1 : ndarray or scalar
Input values.
x1 : ndarray or scalar
Number of bits to remove at the right of x1. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.binary_repr(10)
'1010'
>>> np.right_shift(10, 1)
5
>>> np.binary_repr(5)
'101'
>>> np.right_shift(10, np.array([1,2,3]))
array([5, 2, 1])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.right_shift(x1, x2, out=out)
return _api_internal.bitwise_right_shift(x1, x2, out)
104 changes: 103 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
'positive', 'logaddexp', 'floor_divide', 'permute_dims']
'positive', 'logaddexp', 'floor_divide', 'permute_dims', 'bitwise_left_shift', 'bitwise_right_shift']

__all__ += fallback.__all__

Expand Down Expand Up @@ -1057,6 +1057,16 @@ def __rxor__(self, other):
"""x.__rxor__(y) <=> y ^ x"""
return bitwise_xor(other, self)

@wrap_mxnp_np_ufunc
def __lshift__(self, other):
"""x.__lshift__(y) <=> x << y"""
return bitwise_left_shift(self, other)

@wrap_mxnp_np_ufunc
def __rshift__(self, other):
"""x.__rshift__(y) <=> x >> y"""
return bitwise_right_shift(self, other)

@wrap_mxnp_np_ufunc
def __iand__(self, other):
"""x.__iand__(y) <=> x &= y"""
Expand All @@ -1072,6 +1082,26 @@ def __ixor__(self, other):
"""x.__ixor__(y) <=> x ^= y"""
return bitwise_xor(self, other, out=self)

@wrap_mxnp_np_ufunc
def __ilshift__(self, other):
"""x.__ilshift__(y) <=> x <<= y"""
return bitwise_left_shift(self, other, out=self)

@wrap_mxnp_np_ufunc
def __irshift__(self, other):
"""x.__irshift__(y) <=> x >>= y"""
return bitwise_right_shift(self, other, out=self)

@wrap_mxnp_np_ufunc
def __rlshift__(self, other):
"""x.__rlshift__(y) <=> y << x"""
return bitwise_left_shift(other, self)

@wrap_mxnp_np_ufunc
def __rrshift__(self, other):
"""x.__rrshift__(y) <=> y >> x"""
return bitwise_right_shift(other, self)

def __round__(self, n=0):
"""x.__round__(n)"""
return round(self, decimals=n)
Expand Down Expand Up @@ -13033,3 +13063,75 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
"""
return _mx_nd_np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
# pylint: enable=redefined-outer-name, too-many-arguments


@set_module('mxnet.numpy')
def bitwise_left_shift(x1, x2, out=None):
r"""
Shift the bits of and integer to the left. Bits are shifted to the left by
appending x2 0s at the right of x1. Since the internal representation of numbers
is in binary format, this operation is equivalent to ``x1 * 2**x2``
Parameters
----------
x1 : ndarray or scalar
Input values.
x2 : ndarray or scalar
Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.binary_repr(5)
'101'
>>> np.left_shift(5, 2)
20
>>> np.binary_repr(20)
'10100'
"""
return _mx_nd_np.bitwise_left_shift(x1, x2, out)


@set_module('mxnet.numpy')
def bitwise_right_shift(x1, x2, out=None):
r"""
Shift the bits of and integer to the right. Bits are shifted to the right by
x2. Because the internal representation of numbers is in binary format,
this operation is equivalent to ``x1 / 2**x2``
Parameters
----------
x1 : ndarray or scalar
Input values.
x1 : ndarray or scalar
Number of bits to remove at the right of x1. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.binary_repr(10)
'1010'
>>> np.right_shift(10, 1)
5
>>> np.binary_repr(5)
'101'
>>> np.right_shift(10, np.array([1,2,3]))
array([5, 2, 1])
"""
return _mx_nd_np.bitwise_right_shift(x1, x2, out)
18 changes: 18 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,22 @@ MXNET_REGISTER_API("_npi.ldexp").set_body([](runtime::MXNetArgs args, runtime::M
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.bitwise_left_shift")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_left_shift");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_left_shift_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_left_shift_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.bitwise_right_shift")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_bitwise_right_shift");
const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_right_shift_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_right_shift_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet
44 changes: 44 additions & 0 deletions src/common/cuda/rtc/backward_functions-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,50 @@ copysign_grad(const DType val,
return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_grad(const DType val,
const DType2 val2) {
return op::power(static_cast<DType>(2), val2);
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_left_shift_rgrad(const DType val,
const DType2 val2) {
using type = mixed_type<DType, DType2>;
return val * op::power(static_cast<DType>(2), val2) * op::log(static_cast<type>(2));
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_left_shift_grad(const DType val,
const DType2 val2) {
using type = mixed_type<DType, DType2>;
return val2 * op::power(static_cast<DType>(2), val) * op::log(static_cast<type>(2));
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_grad(const DType val,
const DType2 val2) {
return op::power(0.5f, val2);
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
bitwise_right_shift_rgrad(const DType val,
const DType2 val2) {
return val * op::power(0.5f, val2) * op::log(0.5f);
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
rbitwise_right_shift_grad(const DType val,
const DType2 val2) {
return val2 * op::power(0.5f, val) * op::log(0.5f);
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2>
arctan2_grad(const DType val,
Expand Down
32 changes: 32 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,38 @@ __device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
return real_a & real_b;
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_left_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_a << real_b;
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_left_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_b << real_a;
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> bitwise_right_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_a >> real_b;
}
template <typename DType, typename DType2>
__device__ inline mixed_type<DType, DType2> rbitwise_right_shift(const DType a,
const DType2 b) {
const mixed_type<DType, DType2> real_a = a;
const mixed_type<DType, DType2> real_b = b;
return real_b >> real_a;
}
DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
template <typename DType, typename DType2>
Expand Down
2 changes: 1 addition & 1 deletion src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::

ThreadedOpr* threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
auto* event_pool = static_cast<CUDAEventPool*>(info->event_pool);
auto [event, event_pool_idx] = event_pool->GetNextEvent();
auto [event, event_pool_idx] = event_pool->GetNextEvent(); // NOLINT(*)
auto ev = event.lock();
MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_));
for (auto* read_var : threaded_opr->const_vars) {
Expand Down
38 changes: 38 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,44 @@ MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));

#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif

/*! \brief used for generate element of bitwise_left_shift */
MXNET_BINARY_MATH_OP(bitwise_left_shift, static_cast<int64_t>(a) << static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_left_shift_grad, math::pow(2.0f, static_cast<int64_t>(b)));

MXNET_BINARY_MATH_OP(bitwise_left_shift_rgrad,
static_cast<int64_t>(a) * math::pow(2.0f, static_cast<int64_t>(b)) *
math::log(2.0f));

MXNET_BINARY_MATH_OP(rbitwise_left_shift, static_cast<int64_t>(b) << static_cast<int64_t>(a));

MXNET_BINARY_MATH_OP(rbitwise_left_shift_grad,
static_cast<int64_t>(b) * math::pow(2.0f, static_cast<int64_t>(a)) *
math::log(2.0f));

/*! \brief used for generate element of bitwise_right_shift */
MXNET_BINARY_MATH_OP(bitwise_right_shift, static_cast<int64_t>(a) >> static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_right_shift_grad, math::pow(0.5f, static_cast<int64_t>(b)));

MXNET_BINARY_MATH_OP(bitwise_right_shift_rgrad,
static_cast<int64_t>(a) * math::pow(0.5f, static_cast<int64_t>(b)) *
math::log(0.5f));

MXNET_BINARY_MATH_OP(rbitwise_right_shift, static_cast<int64_t>(b) >> static_cast<int64_t>(a));

MXNET_BINARY_MATH_OP(rbitwise_right_shift_grad,
static_cast<int64_t>(b) * math::pow(0.5f, static_cast<int64_t>(a)) *
math::log(0.5f));

#pragma GCC diagnostic pop

MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
Expand Down
Loading

0 comments on commit 943ab64

Please sign in to comment.