From 943ab644b1d33d52fa0d224100c743505eb7e0eb Mon Sep 17 00:00:00 2001
From: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com>
Date: Mon, 1 Nov 2021 09:10:22 -0700
Subject: [PATCH] [API] Add bitwise_left/right_shift (#20587)

* add bitwise_left/right_shift
* add more methods
* add mshadow_op.h
* fix
* fix lint & add tests
* fix
* update operator_tune.cc
* update amp list
* add rtc functions
* fix bitwise rtc functions & numpy op gpu test overriding issue
* clang-format
* fix ci
* add int16 support
* add MXNET_INT_TYPE_SWITCH_EXT
* fix sanity check
* fix lint
* fix
* fix lint
---
 python/mxnet/amp/lists/symbol_fp16.py         |   6 +
 python/mxnet/ndarray/numpy/_op.py             |  80 ++++++-
 python/mxnet/numpy/multiarray.py              | 104 ++++++++-
 .../numpy/np_elemwise_broadcast_op.cc         |  18 ++
 src/common/cuda/rtc/backward_functions-inl.h  |  44 ++++
 src/common/cuda/rtc/forward_functions-inl.h   |  32 +++
 src/engine/threaded_engine.cc                 |   2 +-
 src/operator/mshadow_op.h                     |  38 ++++
 src/operator/mxnet_op.h                       |  54 +++++
 .../np_elemwise_broadcast_op_extended_thi.cc  | 212 ++++++++++++++++++
 .../np_elemwise_broadcast_op_extended_thi.cu  |  72 ++++++
 src/operator/operator_tune.cc                 |  10 +
 .../tensor/elemwise_binary_broadcast_op.h     |   2 +-
 src/operator/tensor/elemwise_binary_op.h      |   2 +-
 .../unittest/test_numpy_interoperability.py   |  26 +++
 tests/python/unittest/test_numpy_op.py        |  52 +++++
 16 files changed, 749 insertions(+), 5 deletions(-)
 create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
 create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cu

diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index e54b523e1fe6..a1404d512834 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -195,6 +195,12 @@
     '_npi_bitwise_or_scalar',
     '_npi_bitwise_xor',
     '_npi_bitwise_xor_scalar',
+    '_npi_bitwise_left_shift',
+    '_npi_bitwise_left_shift_scalar',
+    '_npi_bitwise_right_shift',
+    '_npi_bitwise_right_shift_scalar',
+    '_npi_rbitwise_left_shift_scalar',
+    '_npi_rbitwise_right_shift_scalar',
     '_npi_blackman',
     '_npi_boolean_mask_assign_scalar',
     '_npi_boolean_mask_assign_tensor',
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index b0471c416e14..7378bd6e7a8a 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -51,7 +51,7 @@
            'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf',
            'isfinite', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
            'where', 'bincount', 'rollaxis', 'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'diag', 'diagonal',
-           'positive', 'logaddexp', 'floor_divide']
+           'positive', 'logaddexp', 'floor_divide', 'bitwise_left_shift', 'bitwise_right_shift']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -10015,3 +10015,81 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
         raise ValueError("only where=None or where=True cases are supported for now")
     return _api_internal.sum(a, axis, dtype, keepdims, initial, out)
 # pylint:enable=redefined-outer-name, too-many-arguments
+
+
+@set_module('mxnet.ndarray.numpy')
+def bitwise_left_shift(x1, x2, out=None):
+    r"""
+    Shift the bits of an integer to the left. Bits are shifted to the left by
+    appending x2 0s at the right of x1. Since the internal representation of numbers
+    is in binary format, this operation is equivalent to ``x1 * 2**x2``.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Input values.
+    x2 : ndarray or scalar
+        Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape of the output).
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have a shape that the
+        inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray
+        Result.
+
+    Examples
+    --------
+    >>> np.binary_repr(5)
+    '101'
+    >>> np.left_shift(5, 2)
+    20
+    >>> np.binary_repr(20)
+    '10100'
+    >>> np.left_shift(5, np.array([1,2,3]))
+    array([10, 20, 40])
+    """
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.left_shift(x1, x2, out=out)
+    return _api_internal.bitwise_left_shift(x1, x2, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def bitwise_right_shift(x1, x2, out=None):
+    r"""
+    Shift the bits of an integer to the right. Bits are shifted to the right by
+    x2. Because the internal representation of numbers is in binary format,
+    this operation is equivalent to ``x1 // 2**x2``.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Input values.
+    x2 : ndarray or scalar
+        Number of bits to remove at the right of x1. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape of the output).
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have a shape that the
+        inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray
+        Result.
+
+    Examples
+    --------
+    >>> np.binary_repr(10)
+    '1010'
+    >>> np.right_shift(10, 1)
+    5
+    >>> np.binary_repr(5)
+    '101'
+    >>> np.right_shift(10, np.array([1,2,3]))
+    array([5, 2, 1])
+    """
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.right_shift(x1, x2, out=out)
+    return _api_internal.bitwise_right_shift(x1, x2, out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 04bace50da27..e8fa2a788c2d 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -81,7 +81,7 @@
            'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
            'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
            'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
-           'positive', 'logaddexp', 'floor_divide', 'permute_dims']
+           'positive', 'logaddexp', 'floor_divide', 'permute_dims', 'bitwise_left_shift', 'bitwise_right_shift']
 
 __all__ += fallback.__all__
 
@@ -1057,6 +1057,16 @@ def __rxor__(self, other):
         """x.__rxor__(y) <=> y ^ x"""
         return bitwise_xor(other, self)
 
+    @wrap_mxnp_np_ufunc
+    def __lshift__(self, other):
+        """x.__lshift__(y) <=> x << y"""
+        return bitwise_left_shift(self, other)
+
+    @wrap_mxnp_np_ufunc
+    def __rshift__(self, other):
+        """x.__rshift__(y) <=> x >> y"""
+        return bitwise_right_shift(self, other)
+
     @wrap_mxnp_np_ufunc
     def __iand__(self, other):
         """x.__iand__(y) <=> x &= y"""
@@ -1072,6 +1082,26 @@ def __ixor__(self, other):
         """x.__ixor__(y) <=> x ^= y"""
         return bitwise_xor(self, other, out=self)
 
+    @wrap_mxnp_np_ufunc
+    def __ilshift__(self, other):
+        """x.__ilshift__(y) <=> x <<= y"""
+        return bitwise_left_shift(self, other, out=self)
+
+    @wrap_mxnp_np_ufunc
+    def __irshift__(self, other):
+        """x.__irshift__(y) <=> x >>= y"""
+        return bitwise_right_shift(self, other, out=self)
+
+    @wrap_mxnp_np_ufunc
+    def __rlshift__(self, other):
"""x.__rlshift__(y) <=> y << x""" + return bitwise_left_shift(other, self) + + @wrap_mxnp_np_ufunc + def __rrshift__(self, other): + """x.__rrshift__(y) <=> y >> x""" + return bitwise_right_shift(other, self) + def __round__(self, n=0): """x.__round__(n)""" return round(self, decimals=n) @@ -13033,3 +13063,75 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N """ return _mx_nd_np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) # pylint: enable=redefined-outer-name, too-many-arguments + + +@set_module('mxnet.numpy') +def bitwise_left_shift(x1, x2, out=None): + r""" + Shift the bits of and integer to the left. Bits are shifted to the left by + appending x2 0s at the right of x1. Since the internal representation of numbers + is in binary format, this operation is equivalent to ``x1 * 2**x2`` + + Parameters + ---------- + x1 : ndarray or scalar + Input values. + x2 : ndarray or scalar + Number of zeros to append to x1. Has to be non-negative. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. + + Examples + -------- + >>> np.binary_repr(5) + '101' + >>> np.left_shift(5, 2) + 20 + >>> np.binary_repr(20) + '10100' + """ + return _mx_nd_np.bitwise_left_shift(x1, x2, out) + + +@set_module('mxnet.numpy') +def bitwise_right_shift(x1, x2, out=None): + r""" + Shift the bits of and integer to the right. Bits are shifted to the right by + x2. Because the internal representation of numbers is in binary format, + this operation is equivalent to ``x1 / 2**x2`` + + Parameters + ---------- + x1 : ndarray or scalar + Input values. + x1 : ndarray or scalar + Number of bits to remove at the right of x1. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. 
+
+    Examples
+    --------
+    >>> np.binary_repr(10)
+    '1010'
+    >>> np.right_shift(10, 1)
+    5
+    >>> np.binary_repr(5)
+    '101'
+    >>> np.right_shift(10, np.array([1,2,3]))
+    array([5, 2, 1])
+    """
+    return _mx_nd_np.bitwise_right_shift(x1, x2, out)
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
index 067d419c3cdb..c1f8b7e4b1b0 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -189,4 +189,22 @@ MXNET_REGISTER_API("_npi.ldexp").set_body([](runtime::MXNetArgs args, runtime::M
   UFuncHelper(args, ret, op, op_scalar, op_rscalar);
 });
 
+MXNET_REGISTER_API("_npi.bitwise_left_shift")
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op         = Op::Get("_npi_bitwise_left_shift");
+      const nnvm::Op* op_scalar  = Op::Get("_npi_bitwise_left_shift_scalar");
+      const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_left_shift_scalar");
+      UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+    });
+
+MXNET_REGISTER_API("_npi.bitwise_right_shift")
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op         = Op::Get("_npi_bitwise_right_shift");
+      const nnvm::Op* op_scalar  = Op::Get("_npi_bitwise_right_shift_scalar");
+      const nnvm::Op* op_rscalar = Op::Get("_npi_rbitwise_right_shift_scalar");
+      UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+    });
+
 }  // namespace mxnet
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index d28ce734c14e..c583df373bb9 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -426,6 +426,50 @@ copysign_grad(const DType val,
   return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
 }
 
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+bitwise_left_shift_grad(const DType val,
+                        const DType2 val2) {
+  return op::power(static_cast<mixed_type<DType, DType2>>(2), val2);
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+bitwise_left_shift_rgrad(const DType val,
+                         const DType2 val2) {
+  using type = mixed_type<DType, DType2>;
+  return val * op::power(static_cast<type>(2), val2) * op::log(static_cast<type>(2));
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+rbitwise_left_shift_grad(const DType val,
+                         const DType2 val2) {
+  using type = mixed_type<DType, DType2>;
+  return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+bitwise_right_shift_grad(const DType val,
+                         const DType2 val2) {
+  return op::power(0.5f, val2);
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+bitwise_right_shift_rgrad(const DType val,
+                          const DType2 val2) {
+  return val * op::power(0.5f, val2) * op::log(0.5f);
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2>
+rbitwise_right_shift_grad(const DType val,
+                          const DType2 val2) {
+  return val2 * op::power(0.5f, val) * op::log(0.5f);
+}
+
 template <typename DType, typename DType2>
 __device__ inline mixed_type<DType, DType2>
 arctan2_grad(const DType val,
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 2b457092b3c8..f6a87f401ba6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -617,6 +617,38 @@ __device__ inline mixed_type<DType, DType2> bitwise_and(const DType a,
                                                         const DType2 b) {
   const mixed_type<DType, DType2> real_a = a;
   const mixed_type<DType, DType2> real_b = b;
   return real_a & real_b;
 }
 
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2> bitwise_left_shift(const DType a,
+                                                               const DType2 b) {
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
+  return real_a << real_b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2> rbitwise_left_shift(const DType a,
+                                                                const DType2 b) {
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
+  return real_b << real_a;
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2> bitwise_right_shift(const DType a,
+                                                                const DType2 b) {
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
+  return real_a >> real_b;
+}
+
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType, DType2> rbitwise_right_shift(const DType a,
+                                                                 const DType2 b) {
+  const mixed_type<DType, DType2> real_a = a;
+  const mixed_type<DType, DType2> real_b = b;
+  return real_b >> real_a;
+}
+
 DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
 
 template <typename DType, typename DType2>
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 40d852b83b86..7639fd445987 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -712,7 +712,7 @@ void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc::
   ThreadedOpr* threaded_opr = static_cast<OprBlock*>(info->opr_block)->opr;
 
   auto* event_pool = static_cast<CUDAEventPool*>(info->event_pool);
-  auto [event, event_pool_idx] = event_pool->GetNextEvent();
+  auto [event, event_pool_idx] = event_pool->GetNextEvent();  // NOLINT(*)
   auto ev = event.lock();
   MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_));
   for (auto* read_var : threaded_opr->const_vars) {
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 34f852ddaa02..9a14794a47da 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -812,6 +812,44 @@ MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t
 
 MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));
 
+#pragma GCC diagnostic push
+#if __GNUC__ >= 7
+#pragma GCC diagnostic ignored "-Wint-in-bool-context"
+#pragma GCC diagnostic ignored "-Wbool-compare"
+#endif
+
+/*! \brief used to generate element of bitwise_left_shift */
+MXNET_BINARY_MATH_OP(bitwise_left_shift, static_cast<int64_t>(a) << static_cast<int64_t>(b));
+
+MXNET_BINARY_MATH_OP(bitwise_left_shift_grad, math::pow(2.0f, static_cast<float>(b)));
+
+MXNET_BINARY_MATH_OP(bitwise_left_shift_rgrad,
+                     static_cast<float>(a) * math::pow(2.0f, static_cast<float>(b)) *
+                         math::log(2.0f));
+
+MXNET_BINARY_MATH_OP(rbitwise_left_shift, static_cast<int64_t>(b) << static_cast<int64_t>(a));
+
+MXNET_BINARY_MATH_OP(rbitwise_left_shift_grad,
+                     static_cast<float>(b) * math::pow(2.0f, static_cast<float>(a)) *
+                         math::log(2.0f));
+
+/*! \brief used to generate element of bitwise_right_shift */
+MXNET_BINARY_MATH_OP(bitwise_right_shift, static_cast<int64_t>(a) >> static_cast<int64_t>(b));
+
+MXNET_BINARY_MATH_OP(bitwise_right_shift_grad, math::pow(0.5f, static_cast<float>(b)));
+
+MXNET_BINARY_MATH_OP(bitwise_right_shift_rgrad,
+                     static_cast<float>(a) * math::pow(0.5f, static_cast<float>(b)) *
+                         math::log(0.5f));
+
+MXNET_BINARY_MATH_OP(rbitwise_right_shift, static_cast<int64_t>(b) >> static_cast<int64_t>(a));
+
+MXNET_BINARY_MATH_OP(rbitwise_right_shift_grad,
+                     static_cast<float>(b) * math::pow(0.5f, static_cast<float>(a)) *
+                         math::log(0.5f));
+
+#pragma GCC diagnostic pop
+
 MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));
 
 MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index fd3b5877224e..09e42481a66b 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -419,6 +419,60 @@ struct AccType {
     LOG(FATAL) << "Unknown type enum " << type; \
   }
 
+#define MXNET_INT_TYPE_SWITCH_EXT(type, DType, ...) \
+  switch (type) {                                   \
+    case mshadow::kFloat32: {                       \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float32";   \
+    } break;                                        \
+    case mshadow::kFloat64: {                       \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float64";   \
+    } break;                                        \
+    case mshadow::kFloat16: {                       \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float16";   \
+    } break;                                        \
+    case mshadow::kUint8: {                         \
+      typedef uint8_t DType;                        \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kInt8: {                          \
+      typedef int8_t DType;                         \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kInt32: {                         \
+      typedef int32_t DType;                        \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kInt64: {                         \
+      typedef int64_t DType;                        \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kInt16: {                         \
+      typedef int16_t DType;                        \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kUint16: {                        \
+      typedef uint16_t DType;                       \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kUint32: {                        \
+      typedef uint32_t DType;                       \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kUint64: {                        \
+      typedef uint64_t DType;                       \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    case mshadow::kBool: {                          \
+      typedef bool DType;                           \
+      { __VA_ARGS__ }                               \
+    } break;                                        \
+    default:                                        \
+      LOG(FATAL) << "Unknown type enum " << type;   \
+  }
+
 #define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...) \
   switch (type) {                                       \
     case mshadow::kFloat32: {                           \
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
new file mode 100644
index 000000000000..90ecd6e2387a
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cc
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_broadcast_op_extended_thi.cc
+ * \brief CPU Implementation of extended functions for elementwise numpy binary broadcast operator.
+ * (Third extended file)
+ */
+
+#include "../../common/utils.h"
+#include "./np_elemwise_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)                        \
+  NNVM_REGISTER_OP(name)                                                      \
+      .set_num_inputs(1)                                                      \
+      .set_num_outputs(1)                                                     \
+      .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)                   \
+      .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)       \
+      .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)        \
+      .set_attr<FResourceRequest>(                                            \
+          "FResourceRequest",                                                 \
+          [](const NodeAttrs& attrs) {                                        \
+            return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
+          })                                                                  \
+      .add_argument("data", "NDArray-or-Symbol", "source input")              \
+      .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+
+NNVM_REGISTER_OP(_npi_bitwise_left_shift)
+    .set_num_inputs(2)
+    .set_num_outputs(1)
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       return std::vector<std::string>{"lhs", "rhs"};
+                                     })
+    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastCompute<cpu, op::mshadow_op::bitwise_left_shift>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_left_shift"})
+    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
+    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+
+NNVM_REGISTER_OP(_npi_bitwise_left_shift_scalar)
+    .set_num_inputs(1)
+    .set_num_outputs(1)
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Compute<cpu, op::mshadow_op::bitwise_left_shift>)
+    .set_attr<nnvm::FGradient>("FGradient",
+                               ElemwiseGradUseIn{"_backward_npi_bitwise_left_shift_scalar"})
+    .add_argument("data", "NDArray-or-Symbol", "source input")
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_npi_rbitwise_left_shift_scalar)
+    .set_num_inputs(1)
+    .set_num_outputs(1)
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Compute<cpu, op::mshadow_op::rbitwise_left_shift>)
+    .set_attr<nnvm::FGradient>("FGradient",
+                               ElemwiseGradUseIn{"_backward_npi_rbitwise_left_shift_scalar"})
+    .add_argument("data", "NDArray-or-Symbol", "source input")
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_left_shift)
+    .set_num_inputs(3)
+    .set_num_outputs(2)
+    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 1}};
+                                    })
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastBackwardUseIn<cpu,
+                                                     mshadow_op::bitwise_left_shift_grad,
+                                                     mshadow_op::bitwise_left_shift_rgrad>);
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_left_shift_scalar)
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Backward<cpu, mshadow_op::bitwise_left_shift_grad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rbitwise_left_shift_scalar)
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Backward<cpu, mshadow_op::rbitwise_left_shift_grad>);
+
+NNVM_REGISTER_OP(_npi_bitwise_right_shift)
+    .set_num_inputs(2)
+    .set_num_outputs(1)
+    .set_attr<nnvm::FListInputNames>("FListInputNames",
+                                     [](const NodeAttrs& attrs) {
+                                       return std::vector<std::string>{"lhs", "rhs"};
+                                     })
+    .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastCompute<cpu, op::mshadow_op::bitwise_right_shift>)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_bitwise_right_shift"})
+    .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
+    .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+
+NNVM_REGISTER_OP(_npi_bitwise_right_shift_scalar)
+    .set_num_inputs(1)
+    .set_num_outputs(1)
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Compute<cpu, op::mshadow_op::bitwise_right_shift>)
+    .set_attr<nnvm::FGradient>("FGradient",
+                               ElemwiseGradUseIn{"_backward_npi_bitwise_right_shift_scalar"})
+    .add_argument("data", "NDArray-or-Symbol", "source input")
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_npi_rbitwise_right_shift_scalar)
+    .set_num_inputs(1)
+    .set_num_outputs(1)
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 0}};
+                                    })
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Compute<cpu, op::mshadow_op::rbitwise_right_shift>)
+    .set_attr<nnvm::FGradient>("FGradient",
+                               ElemwiseGradUseIn{"_backward_npi_rbitwise_right_shift_scalar"})
+    .add_argument("data", "NDArray-or-Symbol", "source input")
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_right_shift)
+    .set_num_inputs(3)
+    .set_num_outputs(2)
+    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                    [](const NodeAttrs& attrs) {
+                                      return std::vector<std::pair<int, int>>{{0, 1}};
+                                    })
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastBackwardUseIn<cpu,
+                                                     mshadow_op::bitwise_right_shift_grad,
+                                                     mshadow_op::bitwise_right_shift_rgrad>);
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_right_shift_scalar)
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Backward<cpu, mshadow_op::bitwise_right_shift_grad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rbitwise_right_shift_scalar)
+    .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+    .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+    .set_attr<FCompute>("FCompute",
+                        BinaryScalarOp::Backward<cpu, mshadow_op::rbitwise_right_shift_grad>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cu
new file mode 100644
index 000000000000..5a2159fa3b0d
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_thi.cu
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_elemwise_broadcast_op_extended_thi.cu
+ * \brief GPU Implementation of extended functions for elementwise binary broadcast operator.
+ * (Third extended file)
+ */
+
+#include "./np_elemwise_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_bitwise_left_shift)
+    .set_attr<FCompute>("FCompute", BinaryBroadcastRTCCompute{"bitwise_left_shift"});
+
+NNVM_REGISTER_OP(_npi_bitwise_left_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCCompute{"bitwise_left_shift"});
+
+NNVM_REGISTER_OP(_npi_rbitwise_left_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCCompute{"rbitwise_left_shift"});
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_left_shift)
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastRTCBackwardUseIn{"bitwise_left_shift_grad",
+                                                        "bitwise_left_shift_rgrad"});
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_left_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCBackward{"bitwise_left_shift_grad"});
+
+NNVM_REGISTER_OP(_backward_npi_rbitwise_left_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCBackward{"rbitwise_left_shift_grad"});
+
+NNVM_REGISTER_OP(_npi_bitwise_right_shift)
+    .set_attr<FCompute>("FCompute", BinaryBroadcastRTCCompute{"bitwise_right_shift"});
+
+NNVM_REGISTER_OP(_npi_bitwise_right_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCCompute{"bitwise_right_shift"});
+
+NNVM_REGISTER_OP(_npi_rbitwise_right_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCCompute{"rbitwise_right_shift"});
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_right_shift)
+    .set_attr<FCompute>("FCompute",
+                        BinaryBroadcastRTCBackwardUseIn{"bitwise_right_shift_grad",
+                                                        "bitwise_right_shift_rgrad"});
+
+NNVM_REGISTER_OP(_backward_npi_bitwise_right_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCBackward{"bitwise_right_shift_grad"});
+
+NNVM_REGISTER_OP(_backward_npi_rbitwise_right_shift_scalar)
+    .set_attr<FCompute>("FCompute", BinaryScalarRTCBackward{"rbitwise_right_shift_grad"});
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index d36a881cfc32..b15fa36864f9 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -453,6 +453,16 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or);   // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_left_shift);        // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rbitwise_left_shift);       // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::bitwise_left_shift_grad);   // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::bitwise_left_shift_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rbitwise_left_shift_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_right_shift);       // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rbitwise_right_shift);      // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::bitwise_right_shift_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::bitwise_right_shift_rgrad); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rbitwise_right_shift_grad); // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);      // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::gcd);       // NOLINT()
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index ef7bb83c7c69..fbf42c515225 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -216,7 +216,7 @@ void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs,
   if (outputs[0].type_flag_ == mshadow::kBool) {
     LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
   }
-  MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MXNET_INT_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(ndim, NDim, {
       mshadow::Shape<NDim> oshape  = new_oshape.get<NDim>();
       mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 8339f2000153..b4a7498f0eba 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -449,7 +449,7 @@ class ElemwiseBinaryOp : public OpBase {
     CHECK_EQ(inputs.size(), 2U);
     CHECK_EQ(outputs.size(), 1U);
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_INT_TYPE_SWITCH_EXT(outputs[0].type_flag_, DType, {
         const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) +
                              DataType<DType>::kLanes - 1) /
                             DataType<DType>::kLanes;
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index c8edad6b59f3..51b36f7f6bac 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1530,6 +1530,30 @@ def _add_workload_bitwise_xor():
     OpArgMngr.add_workload('bitwise_xor', ones, ones)
 
 
+def _add_workload_bitwise_left_shift():
+    for dtype in [np.int8, np.int32, np.int64]:
+        twenty = np.array([20], dtype=dtype)
+        three = np.array([3], dtype=dtype)
+        OpArgMngr.add_workload('bitwise_left_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_left_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_left_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_left_shift', twenty, three)
+    OpArgMngr.add_workload('bitwise_left_shift', np.array([9223372036854775807], np.int64), np.array([1], np.int64))
+    OpArgMngr.add_workload('bitwise_left_shift', np.array([-9223372036854775808], np.int64), np.array([1], np.int64))
+
+
+def _add_workload_bitwise_right_shift():
+    for dtype in [np.int8, np.int32, np.int64]:
+        twenty = np.array([20], dtype=dtype)
+        three = np.array([3], dtype=dtype)
+        OpArgMngr.add_workload('bitwise_right_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_right_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_right_shift', twenty, three)
+        OpArgMngr.add_workload('bitwise_right_shift', twenty, three)
+    OpArgMngr.add_workload('bitwise_right_shift', np.array([9223372036854775807], np.int64), np.array([1], np.int64))
+    OpArgMngr.add_workload('bitwise_right_shift', np.array([-9223372036854775808], np.int64), np.array([1], np.int64))
+
+
 def _add_workload_ldexp():
     OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int8))
     OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int8))
@@ -3099,6 +3123,8 @@ def _prepare_workloads():
     _add_workload_bitwise_and()
     _add_workload_bitwise_xor()
     _add_workload_bitwise_or()
+    _add_workload_bitwise_left_shift()
+    _add_workload_bitwise_right_shift()
     _add_workload_ldexp()
     _add_workload_logaddexp(array_pool)
     _add_workload_subtract(array_pool)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 6a2e6acf6183..d689c5d70237 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3054,6 +3054,58 @@ def forward(self, a, *args, **kwargs):
         check_unary_func(func, shape, low, high)
 
 
+@use_np
+@pytest.mark.parametrize('ndim', [2, 3, 4])
+@pytest.mark.parametrize('func,low,high', [
+    ('left_shift', -5, 5),
+    ('right_shift', -5, 5),
+])
+def test_np_bitwise_shift(func, low, high, ndim):
+    def check_unary_func(func, shape, low, high):
+        class TestUnary(HybridBlock):
+            def __init__(self, func):
+                super(TestUnary, self).__init__()
+                self._func = func
+
+            def forward(self, a, b, *args, **kwargs):
+                return getattr(np, self._func)(a, b)
+
+        np_func = getattr(onp, func)
+        mx_func = TestUnary("bitwise_" + func)
+        np_test_data1 = onp.random.randint(low, high, shape).astype(onp.int64)
+        np_test_data2 = onp.random.randint(low + 5, high + 5, shape).astype(onp.int64)
+        mx_test_data1 = mx.numpy.array(np_test_data1).astype(onp.int64)
+        mx_test_data2 = mx.numpy.array(np_test_data2).astype(onp.int64)
+        for hybridize in [True, False]:
+            if hybridize:
+                mx_func.hybridize()
+            np_out = np_func(np_test_data1, np_test_data2)
+            with mx.autograd.record():
+                y = mx_func(mx_test_data1, mx_test_data2)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+            if np_out.dtype == np.bool_:
+                assert y.dtype == np.bool_
+
+        np_out = getattr(onp, func)(np_test_data1, np_test_data2)
+        mx_out = getattr(mx.np, "bitwise_" + func)(mx_test_data1, mx_test_data2)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, where=False)
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, subok=False)
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, dtype=onp.int8)
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, dtype="abcdefg")
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, casting='safe')
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, casting='mxnet')
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, order='C')
+        assertRaises(TypeError, getattr(np, "bitwise_" + func), mx_test_data1, mx_test_data2, order='mxnet')
+
+    for shape in [rand_shape_nd(ndim, dim=3), (1, 0, 2)]:
+        check_unary_func(func, shape, low, high)
+
+
 @use_np
 def test_np_binary_funcs():
     def check_binary_func(func, lshape, rshape, low, high, lgrads, rgrads=None, alltypes=None):
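---

For a quick end-to-end check of the frontend API this patch registers, the sketch below exercises the functional form, the `<<`/`>>` operator overloads, and the scalar paths. It is a minimal sketch, not part of the diff, assuming an MXNet build that includes this change; the array values and dtypes are illustrative only.

    # Sanity-check the operators added in this patch
    # (requires an MXNet build that includes this change).
    from mxnet import np, npx
    import numpy as onp

    npx.set_np()  # enable numpy-style semantics

    a = np.array([5, 10, 20], dtype='int64')
    b = np.array([1, 2, 3], dtype='int64')

    # Functional form: elementwise a * 2**b.
    lhs = np.bitwise_left_shift(a, b)
    print(lhs)                           # [ 10  40 160]

    # Operator form dispatches through ndarray.__lshift__/__rshift__.
    print(a << b)                        # same as above
    print(a >> b)                        # [2 2 2], i.e. a // 2**b

    # A scalar operand routes to the *_scalar ops registered above.
    print(np.bitwise_right_shift(a, 1))  # [ 2  5 10]

    # Results agree with the NumPy reference implementation.
    assert (lhs.asnumpy()
            == onp.left_shift(onp.array([5, 10, 20]), onp.array([1, 2, 3]))).all()

Note on the backward registrations: the gradients defined in mshadow_op.h and the RTC headers differentiate the smooth surrogates ``x1 * 2**x2`` (left shift) and ``x1 * 2**(-x2)`` (right shift), i.e. ``2**x2`` / ``0.5**x2`` with respect to ``x1`` and ``x1 * 2**x2 * ln 2`` / ``x1 * 0.5**x2 * ln 0.5`` with respect to ``x2``.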