diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 16c579fefa32..df06b5c19641 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1582,8 +1582,14 @@ static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; void NDArray::Save(dmlc::Stream *strm) const { // TODO(junwu): Support this after NumPy operators are merged - CHECK(!Imperative::Get()->is_np_shape()) - << "Saving ndarray within the scope of np_shape is not supported."; + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(storage_type(), kDefaultStorage) + << "only allow serializing ndarray of default storage type within the scope of np_shape"; + CHECK_NE(shape_.Size(), 0U) + << "serializing zero-size ndarray within the scope of np_shape is not supported"; + CHECK_NE(shape_.ndim(), 0) + << "serializing scalar ndarray within the scope of np_shape is not supported"; + } // write magic number to mark this version // for storage type strm->Write(NDARRAY_V2_MAGIC); @@ -1701,9 +1707,6 @@ bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { } bool NDArray::Load(dmlc::Stream *strm) { - // TODO(junwu): Support this after NumPy operators are merged - CHECK(!Imperative::Get()->is_np_shape()) - << "Loading ndarray within the scope of np_shape is not supported."; uint32_t magic; if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; if (magic != NDARRAY_V2_MAGIC) { @@ -1724,6 +1727,15 @@ bool NDArray::Load(dmlc::Stream *strm) { // load shape mxnet::TShape shape; if (!shape.Load(strm)) return false; + // TODO(junwu): Support this after NumPy operators are merged + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(stype, kDefaultStorage) + << "only allow deserializing ndarray of default storage type within the scope of np_shape"; + CHECK_NE(shape.Size(), 0U) + << "deserializing zero-size ndarray within the scope of np_shape is not supported"; + CHECK_NE(shape.ndim(), 0) + << "deserializing scalar ndarray within the scope of np_shape is not supported"; + } if (shape.ndim() == 0) { *this = NDArray(); return true; } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8b2a270a34a2..fd23548c64e7 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1703,16 +1703,25 @@ def test_zero_from_numpy(): @with_seed() def test_save_load_zero_size_ndarrays(): - shapes = [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)] - array_list = [np.random.randint(0, 10, size=shape) for shape in shapes] - array_list = [mx.nd.array(arr) for arr in array_list] - with TemporaryDirectory() as work_dir: - fname = os.path.join(work_dir, 'dataset') - mx.nd.save(fname, array_list) - array_list_loaded = mx.nd.load(fname) - assert len(array_list) == len(array_list_loaded) - for a1, a2 in zip(array_list, array_list_loaded): - assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + def check_save_load(is_np_shape, shapes, throw_exception): + with mx.np_shape(is_np_shape): + array_list = [np.random.randint(0, 10, size=shape) for shape in shapes] + array_list = [mx.nd.array(arr) for arr in array_list] + with TemporaryDirectory() as work_dir: + fname = os.path.join(work_dir, 'dataset') + if throw_exception: + assert_exception(mx.nd.save, mx.MXNetError, fname, array_list) + else: + mx.nd.save(fname, array_list) + array_list_loaded = mx.nd.load(fname) + assert len(array_list) == len(array_list_loaded) + for a1, a2 in zip(array_list, array_list_loaded): + assert np.array_equal(a1.asnumpy(), a2.asnumpy()) + + check_save_load(False, [(2, 0, 1), (0,), (0, 4), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False) # legacy mode + check_save_load(True, [(2, 1), (3, 5)], False) # np_shape semantics, no zero-size, should succeed + check_save_load(True, [(2, 1), (3, 0)], True) # np_shape semantics, zero-size, should fail + check_save_load(True, [(2, 1), ()], True) # np_shape semantics, scalar tensor, should fail if __name__ == '__main__':