Skip to content

Commit

Permalink
fix norm sparse fallback (apache#17149)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and eric-haibin-lin committed Jan 2, 2020
1 parent 80a850d commit 8612372
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_norm_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void L2NormComputeEx<cpu>(const nnvm::NodeAttrs& attrs,
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
const NDArrayStorageType istype = inputs[0].storage_type();
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape();
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1);
if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
param.ord == 2) {
// l2 norm on the entire array
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_norm_value.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void L2NormComputeEx<gpu>(const nnvm::NodeAttrs& attrs,
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
const NDArrayStorageType istype = inputs[0].storage_type();
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape();
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1);
if ((istype == kRowSparseStorage || istype == kCSRStorage) && axis.ndim() == 0 &&
param.ord == 2) {
// l2 norm on the entire array
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ inline bool LpNormStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode::kFCompute);
}
if (param.ord == 2) {
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape();
const mxnet::TShape axis = param.axis.has_value() ? param.axis.value() : mxnet::TShape(0, -1);
if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
axis.ndim() == 0 && param.ord == 2) {
// l2 norm: rsp/csr, axis = () -> dns
Expand Down

0 comments on commit 8612372

Please sign in to comment.