diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h index 20b8319ac110..52f2da322e7d 100644 --- a/src/operator/nn/dnnl/dnnl_base-inl.h +++ b/src/operator/nn/dnnl/dnnl_base-inl.h @@ -198,6 +198,7 @@ bool SupportDNNLBatchDot(const std::vector& inputs, const NDArray& outp bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector& inputs); bool SupportDNNLReshape(const NDArray& input, const NDArray& output); bool SupportDNNLStack(const std::vector& inputs); +bool SupportDNNLBinary(const std::vector& inputs); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h b/src/operator/nn/dnnl/dnnl_binary-inl.h new file mode 100644 index 000000000000..c439065b9dc2 --- /dev/null +++ b/src/operator/nn/dnnl/dnnl_binary-inl.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_binary-inl.h + * \author: Adam Grabowski, adam.grabowski@intel.com + */ + +#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_ +#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_ + +#if MXNET_USE_ONEDNN == 1 +#include "./dnnl_base-inl.h" +#include "./dnnl_ops-inl.h" +#include + +#include "../../tensor/elemwise_binary_broadcast_op.h" + +namespace mxnet { +namespace op { + +using binary_op_fwd_t = dnnl::binary; +using binary_op_fwd_pd_t = dnnl::binary::primitive_desc; + +class DNNLBinaryOpFwd { + public: + template + static DNNLBinaryOpFwd& GetBinaryOpForward(const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs); + DNNLBinaryOpFwd(const dnnl::algorithm alg, + const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs); + + void Execute(const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + std::shared_ptr fwd; + std::shared_ptr fwd_pd; +}; + +template +DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs) { + using binary_op_fwd_map = std::unordered_map; +#if DMLC_CXX11_THREAD_LOCAL + static thread_local binary_op_fwd_map fwds; +#else + static MX_THREAD_LOCAL binary_op_fwd_map fwds; +#endif + OpSignature key; + key.AddSign(static_cast(alg)); + key.AddSign(inputs[0]); + key.AddSign(inputs[1]); + key.AddSign(outputs[0]); + + auto it = fwds.find(key); + if (it == fwds.end()) { + const DNNLBinaryOpFwd fwd(alg, attrs, inputs, outputs); + it = AddToCache(&fwds, key, fwd); + } + return it->second; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_ diff --git a/src/operator/nn/dnnl/dnnl_binary.cc b/src/operator/nn/dnnl/dnnl_binary.cc new file mode 100644 index 000000000000..25a240d8f917 --- /dev/null +++ b/src/operator/nn/dnnl/dnnl_binary.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_binary.cc + * \author: Adam Grabowski, adam.grabowski@intel.com + */ + +#if MXNET_USE_ONEDNN == 1 +#include "./dnnl_binary-inl.h" + +namespace mxnet { +namespace op { + +DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg, + const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs) { + auto src0_desc = inputs[0].GetDNNLData()->get_desc(); + auto src1_desc = inputs[1].GetDNNLData()->get_desc(); + auto dst_desc = outputs[0].GetDNNLData()->get_desc(); + + dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc); + fwd_pd = std::make_shared(fwd_desc, mxnet::CpuEngine::Get()->get_engine()); + fwd = std::make_shared(*fwd_pd); +} + +void DNNLBinaryOpFwd::Execute(const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto engine = mxnet::CpuEngine::Get()->get_engine(); + auto src0 = + dnnl::memory(fwd_pd->src0_desc(), engine, reinterpret_cast(inputs[0].data().dptr_)); + auto src1 = + dnnl::memory(fwd_pd->src1_desc(), engine, reinterpret_cast(inputs[1].data().dptr_)); + dnnl_output_t out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]); + + dnnl_args_map_t args = { + {DNNL_ARG_SRC_0, src0}, + {DNNL_ARG_SRC_1, src1}, + {DNNL_ARG_DST, *out_mem.second}, + }; + + DNNLStream::Get()->RegisterPrimArgs(*fwd, args); + CommitOutput(outputs[0], out_mem); + DNNLStream::Get()->Submit(); +} + +bool SupportDNNLBinary(const std::vector& inputs) { + auto dtype = inputs[0].dtype(); + return inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 && + inputs[0].shape().ndim() != 0 && inputs[1].shape().ndim() != 0 && + dtype == mshadow::kFloat32 && dtype == inputs[1].dtype(); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index fa329bf248d5..baf762a31ea2 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -851,6 +851,53 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, } } +inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + + return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +} + +void NumpyDivideBroadcastComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + +template +void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_ONEDNN == 1 + if (SupportDNNLBinary(inputs)) { + const dnnl::algorithm alg = GetDNNLAlgorithm::dnnl_alg; + DNNLRun(DNNLBinaryOpForward, attrs, ctx, inputs, req, outputs); + return; + } +#endif + using namespace op::mshadow_op; + std::vector in_data = {inputs[0].data(), inputs[1].data()}; + std::vector out_data = {outputs[0].data()}; + if (std::is_same::value) { + NumpyBinaryBroadcastComputeWithBool( + attrs, ctx, in_data, req, out_data); + } else if (std::is_same::value) { + NumpyBinaryBroadcastCompute( + attrs, ctx, in_data, req, out_data); + } else if (std::is_same::value) { + NumpyBinaryBroadcastComputeWithBool( + attrs, ctx, in_data, req, out_data); + } else if (std::is_same::value) { + NumpyDivideBroadcastComputeExCPU(attrs, ctx, in_data, req, out_data); + } +} + #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ diff --git a/src/operator/numpy/np_elemwise_broadcast_op_add.cc b/src/operator/numpy/np_elemwise_broadcast_op_add.cc index 50a79ab5dc2f..f57cfda97075 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_add.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_add.cc @@ -33,6 +33,9 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) op::mshadow_op::plus, op::mshadow_op::mixed_plus, op::mshadow_op::mixed_plus>) + .set_attr("FComputeEx", + NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"}); NNVM_REGISTER_OP(_backward_npi_broadcast_add) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc index 3e627c8c7e10..5eb51de4c300 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc @@ -33,6 +33,9 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) op::mshadow_op::mul, op::mshadow_op::mixed_mul, op::mshadow_op::mixed_mul>) + .set_attr("FComputeEx", + NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mul) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc index 5f3ba7653549..de6c41a9df5e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc @@ -33,6 +33,9 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) op::mshadow_op::minus, op::mshadow_op::mixed_minus, op::mshadow_op::mixed_rminus>) + .set_attr("FComputeEx", + NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"}); NNVM_REGISTER_OP(_backward_npi_broadcast_sub) diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index 047489f648cc..ef813fec6f9e 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -28,7 +28,7 @@ #include #include "../../common/utils.h" #include "../tensor/elemwise_binary_broadcast_op.h" -#include "../numpy/np_elemwise_broadcast_op.h" +#include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { @@ -328,6 +328,14 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } } +void NumpyDivideBroadcastComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TrueDivideBroadcastCompute(attrs, ctx, inputs, req, outputs); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 639379d36cd0..d0e0f0deb42a 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -79,6 +79,9 @@ NNVM_REGISTER_OP(_npi_true_divide) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", TrueDivideBroadcastCompute) + .set_attr("FComputeEx", + NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"}) .add_argument("lhs", "NDArray-or-Symbol", "Dividend array") .add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 20d874dbd826..c57308481844 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -91,8 +91,13 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs, int& out_stype = out_attrs->at(0); bool dispatched = false; if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { +#if MXNET_USE_ONEDNN == 1 + dispatched = + storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); +#else dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); +#endif } if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) { dispatched = @@ -116,8 +121,13 @@ inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs, int& out_stype = out_attrs->at(0); bool dispatched = false; if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { +#if MXNET_USE_ONEDNN == 1 + dispatched = + storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); +#else dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); +#endif } if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) || (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) { @@ -788,6 +798,35 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs, } } +#if MXNET_USE_ONEDNN == 1 +template +void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + +// template struct converting op::mshadow_op to dnnl::algorithm +template +struct GetDNNLAlgorithm {}; +template <> +struct GetDNNLAlgorithm { + static const dnnl::algorithm dnnl_alg = dnnl::algorithm::binary_add; +}; +template <> +struct GetDNNLAlgorithm { + static const dnnl::algorithm dnnl_alg = dnnl::algorithm::binary_sub; +}; +template <> +struct GetDNNLAlgorithm { + static const dnnl::algorithm dnnl_alg = dnnl::algorithm::binary_mul; +}; +template <> +struct GetDNNLAlgorithm { + static const dnnl::algorithm dnnl_alg = dnnl::algorithm::binary_div; +}; +#endif + #define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(2) \ diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 9d0f107aa760..3bac3124bac0 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -24,9 +24,78 @@ #include "./elemwise_unary_op.h" #include "./elemwise_binary_op-inl.h" #include "./elemwise_binary_broadcast_op.h" +#if MXNET_USE_ONEDNN == 1 +#include "../nn/dnnl/dnnl_binary-inl.h" +#endif namespace mxnet { namespace op { + +#if MXNET_USE_ONEDNN == 1 +template +void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape(), + inputs[1].shape(), + outputs[0].shape(), + &new_lshape, + &new_rshape, + &new_oshape); + std::vector new_inputs; + std::vector new_outputs; + if (ndim) { + new_inputs = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)}; + new_outputs = {outputs[0].Reshape(new_oshape)}; + } else if (inputs[0].shape().Size() == 1 && inputs[0].shape().Size() == 1) { + // BinaryBroadcastShapeCompact function doesn't reshape shape().Size() == 1 tensors + // into shape (1). It is mandatory for oneDNN primitive to have this reshape done. + mxnet::TShape one_shape = mxnet::TShape(1, 1); + new_inputs = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)}; + new_outputs = {outputs[0].Reshape(one_shape)}; + } else { + new_inputs = {inputs[0], inputs[1]}; + new_outputs = {outputs[0]}; + } + auto i0 = inputs[0]; + auto i1 = inputs[1]; + + DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward(attrs, new_inputs, new_outputs); + fwd.Execute(new_inputs, req, new_outputs); +} +#endif + +template +static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { +#if MXNET_USE_ONEDNN == 1 + if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) { + if (SupportDNNLBinary(inputs)) { + const dnnl::algorithm alg = GetDNNLAlgorithm::dnnl_alg; + DNNLRun(DNNLBinaryOpForward, attrs, ctx, inputs, req, outputs); + } else { + std::vector in_data = {inputs[0].data(), inputs[1].data()}; + std::vector out_data = {outputs[0].data()}; + BinaryBroadcastCompute(attrs, ctx, in_data, req, out_data); + } + return; + } +#endif // MXNET_USE_ONEDNN == 1 + if (std::is_same::value || + std::is_same::value) { + BinaryBroadcastComputeDenseEx(attrs, ctx, inputs, req, outputs); + } else if (std::is_same::value || + std::is_same::value) { + BinaryBroadcastComputeSparseEx(attrs, ctx, inputs, req, outputs); + } +} + MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add) MXNET_ADD_SPARSE_OP_ALIAS(broadcast_add) MXNET_ADD_SPARSE_OP_ALIAS(broadcast_plus) @@ -56,8 +125,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeDenseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastAddStorageType) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); @@ -106,8 +174,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeDenseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastAddStorageType) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); @@ -148,8 +215,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeSparseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastMulStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); @@ -189,8 +255,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeSparseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastMulStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index dfb012e4c538..a781c90d6bba 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -927,9 +927,9 @@ def test_sign(): assert_almost_equal(out, npout) out_grad = mx.nd.empty(shape) - out_grad[:] = 2; + out_grad[:] = 2 npout_grad = out_grad.asnumpy() - npout_grad = 0; + npout_grad = 0 exe_test.backward(out_grad) assert_almost_equal(arr_grad, npout_grad) @@ -1076,7 +1076,7 @@ def test_abs(): assert_almost_equal(out, npout) out_grad = mx.nd.empty(shape) - out_grad[:] = 2; + out_grad[:] = 2 npout_grad = out_grad.asnumpy() npout_grad = npout_grad * np.sign(data_tmp) exe_test.backward(out_grad) @@ -1915,7 +1915,11 @@ def gen_broadcast_data(idx): [[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]], [[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]], [[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]], - [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]]]) + [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]], + [[8, 1, 6, 1], [7, 1, 5]], [[5, 4], [1]], + [[256, 256, 3], [3]], [[5, 4], [4]], + [[15, 3, 5], [3, 5]], [[15, 3, 5], [1, 5]], + [[15, 3, 5], [3, 1]]]) if idx < binary_op_data_shape.shape[0]: l_shape = binary_op_data_shape[idx][0] r_shape = binary_op_data_shape[idx][1] @@ -1939,7 +1943,7 @@ def gen_broadcast_data(idx): def gen_broadcast_data_int(idx): - d = gen_broadcast_data(idx); + d = gen_broadcast_data(idx) return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] @@ -1951,7 +1955,7 @@ def gen_binary_data(dummy): def gen_binary_data_int(dummy): - d = gen_binary_data(dummy); + d = gen_binary_data(dummy) return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] @@ -2012,10 +2016,16 @@ def reduce_op(shape, x): if shape == x.shape: return x keepdims_shape = list(x.shape) + #calculate difference between output and input ndims + # to include cases where inputs' ndims are not equal + ndim_diff = len(x.shape) - len(shape) + for i in range(ndim_diff): + keepdims_shape[i] = 1 + x = np.sum(x, axis=i).reshape(keepdims_shape) for i in range(len(shape)): - if x.shape[i] != shape[i]: - keepdims_shape[i] = 1 - x = np.sum(x, axis=i).reshape(keepdims_shape) + if x.shape[ndim_diff + i] != shape[i]: + keepdims_shape[ndim_diff + i] = 1 + x = np.sum(x, axis=ndim_diff + i).reshape(keepdims_shape) return x baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1])