Skip to content

Commit

Permalink
[Numpy] Misc fix (apache#14612)
Browse files Browse the repository at this point in the history
* [Numpy] Misc Fix

* fix build

* !shape_is_none => shape_is_known

* Address comments

* Fix
  • Loading branch information
junrushao authored and reminisce committed Apr 10, 2019
1 parent 081d6ff commit 536f899
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ class NDArray {
/*! \brief set the shape for ith aux data, and update storage shape if necessary */
inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) {
aux_shapes[i] = shape;
if (storage_shape.ndim() > 0) {
if (storage_shape.ndim() >= 0) {
if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
storage_shape[0] = shape[0];
} else if (storage_type == kCSRStorage && i == csr::kIdx) {
Expand Down
1 change: 1 addition & 0 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ int MXPredGetOutputShape(PredictorHandle handle,
<< "Index exceed number of outputs";

const mxnet::TShape& s = p->out_shapes[out_index];
CHECK_GE(s.ndim(), 0);
p->out_shapes_buffer.resize(s.ndim());
nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data());
*shape_data = p->out_shapes_buffer.data();
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
#include "../operator/operator_common.h"

#ifndef MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
#define MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_
Expand Down Expand Up @@ -196,7 +197,7 @@ inline void SetShapeType(const Context& ctx,

for (size_t i = 0; i < outputs.size(); ++i) {
NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
if (outputs[i]->is_none() || outputs[i]->shape().ndim() == 0) {
if (outputs[i]->is_none() || mxnet::op::shape_is_none(outputs[i]->shape())) {
if (is_dynamic_shape_existing) {
// once there is dynamic shape somewhere, we could not pre-determine the shape.
*outputs[i] = NDArray(ctx, out_types[i]);
Expand Down
10 changes: 5 additions & 5 deletions src/kvstore/gradient_compression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {

void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
mxnet::NDArray *residual, const int priority) {
CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape";
CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape";
CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape";
const int a = from.ctx().dev_mask();
const int b = to->ctx().dev_mask();
const float threshold = threshold_;
Expand Down Expand Up @@ -137,8 +137,8 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t

void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
const int priority) {
CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape";
CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
CHECK(shape_is_known(from.shape())) << "source operand has undefined shape";
CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape";
const int a = from.ctx().dev_mask();
const int b = to->ctx().dev_mask();
const float threshold = threshold_;
Expand Down
11 changes: 7 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,8 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_op
CHECK(from.shape() == to.shape())
<< "operands shape mismatch"
<< "from.shape = " << from.shape() << " to.shape=" << to.shape();
CHECK(from.shape().ndim() != 0)
<< "source operands have zero dimension shape";
CHECK(!mxnet::op::shape_is_none(from.shape()))
<< "source operands have undefined shape";
// important: callback must always capture by value
const Context from_ctx = from.ctx();
const int a = from_ctx.dev_mask();
Expand Down Expand Up @@ -1663,7 +1663,7 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
// load shape
mxnet::TShape shape;
if (!LegacyTShapeLoad(strm, &shape, magic)) return false;
if (shape.ndim() == 0) {
if (mxnet::op::shape_is_none(shape)) {
*this = NDArray(); return true;
}
// load context
Expand Down Expand Up @@ -1711,7 +1711,10 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
if (shape.ndim() == 0) {
if (!Imperative::Get()->is_np_comp()) {
common::ConvertToNumpyShape(&shape);
}
if (mxnet::op::shape_is_none(shape)) {
*this = NDArray(); return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ndarray/ndarray_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace ndarray {
struct BinaryBase {
inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) {
CHECK(lshape == rshape) << "operands shape mismatch";
CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape";
CHECK(!mxnet::op::shape_is_none(lshape)) << "source operand have zero dimension shape";
return lshape;
}
};
Expand Down

0 comments on commit 536f899

Please sign in to comment.