Skip to content

Commit

Permalink
Restore save/load ndarray to 1.4.1 (apache#15073)
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce authored and haohuw committed Jun 23, 2019
1 parent 442d466 commit 0d41135
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,9 @@ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

void NDArray::Save(dmlc::Stream *strm) const {
// TODO(junwu): Support this after NumPy operators are merged
CHECK(!Imperative::Get()->is_np_comp())
<< "Saving ndarray within the scope of np_shape is not supported.";
// write magic number to mark this version
// for storage type
strm->Write(NDARRAY_V2_MAGIC);
Expand Down Expand Up @@ -1698,6 +1701,9 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) {
}

bool NDArray::Load(dmlc::Stream *strm) {
// TODO(junwu): Support this after NumPy operators are merged
CHECK(!Imperative::Get()->is_np_comp())
<< "Loading ndarray within the scope of np_shape is not supported.";
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
Expand All @@ -1718,10 +1724,7 @@ bool NDArray::Load(dmlc::Stream *strm) {
// load shape
mxnet::TShape shape;
if (!shape.Load(strm)) return false;
if (!Imperative::Get()->is_np_comp()) {
common::ConvertToNumpyShape(&shape);
}
if (mxnet::op::shape_is_none(shape)) {
if (shape.ndim() == 0) {
*this = NDArray(); return true;
}

Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,20 @@ def test_zero_from_numpy():
assert False


@with_seed()
def test_save_load_zero_size_ndarrays():
shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)]
array_list = [np.random.randint(0, 10, size=shape) for shape in shapes]
array_list = [mx.nd.array(arr) for arr in array_list]
with TemporaryDirectory() as work_dir:
fname = os.path.join(work_dir, 'dataset')
mx.nd.save(fname, array_list)
array_list_loaded = mx.nd.load(fname)
assert len(array_list) == len(array_list_loaded)
for a1, a2 in zip(array_list, array_list_loaded):
assert np.array_equal(a1.asnumpy(), a2.asnumpy())


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 0d41135

Please sign in to comment.