From a1e8f5f2427c572f0ff1b8b6b1d81e7e0044d194 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 25 Mar 2019 21:08:41 -0700 Subject: [PATCH] Fix a bug to pass the test in test_contrib_rnn (#14520) * fix. * remove type conversion. * remove type cast. --- src/common/utils.h | 1 - src/ndarray/ndarray_function.cc | 6 ++++-- src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/common/utils.h b/src/common/utils.h index f3df2e15ec32..4843d7e06b7b 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -752,7 +752,6 @@ inline void ConvertToNumpyShape(mxnet::TShape* shape) { *shape = mxnet::TShape(); // unknown shape ndim = -1 } else { for (int j = 0; j < shape->ndim(); ++j) { - CHECK_GE((*shape)[j], 0) << "Legacy shape cannot have dim size < 0"; if ((*shape)[j] == 0) { // legacy shape dim_size = 0 means unknown (*shape)[j] = -1; // unknown dim size = -1 } diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index a613d5a3decc..8f72bc259afc 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -210,8 +210,6 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, Kernel::Launch(s, out_data.Size(), out_data.dptr()); for (size_t i = 0; i < nds.size(); ++i) { const NDArray& nd = nds[i]; - const nnvm::dim_t num_rows = nd.shape()[0]; - const nnvm::dim_t num_cols = nd.shape()[1]; const TBlob& nd_data = nd.data(); if (i == 0) { @@ -234,6 +232,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, case kCSRStorage: { const TBlob& nd_indices = nd.aux_data(csr::kIdx); const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); + const nnvm::dim_t num_rows = nd.shape()[0]; + const nnvm::dim_t num_cols = nd.shape()[1]; MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, { // indptr type if (nd.storage_initialized()) { @@ -248,6 +248,8 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, } case kRowSparseStorage: { const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx); + const nnvm::dim_t num_rows = nd.shape()[0]; + const nnvm::dim_t num_cols = nd.shape()[1]; MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type if (nd.storage_initialized()) { const nnvm::dim_t nz_rows = nd_indices.Size(); diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index a460e33fa548..3da3f23d7683 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -464,7 +464,7 @@ mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc p mkldnn_memory_format_t format); inline bool same_shape(const mxnet::TShape &shape, const mkldnn_dims_t dims, int ndims) { - if (shape.ndim() != (size_t)ndims) + if (shape.ndim() != ndims) return false; for (int i = 0; i < ndims; i++) if (shape[i] != dims[i])