From 38383e7e2cdcb1d7c031580182b262e8fc235e32 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 18 Apr 2019 21:19:08 -0700
Subject: [PATCH] should work

---
 include/mxnet/c_api.h                 | 17 ++++++++
 include/mxnet/ndarray.h               | 13 ++++++
 python/mxnet/ndarray/ndarray.py       | 57 +++++++++++++--------------
 src/c_api/c_api.cc                    |  8 ++++
 src/ndarray/ndarray.cc                | 12 ++++++
 tests/python/unittest/test_ndarray.py |  8 ++++
 6 files changed, 85 insertions(+), 30 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 0acfde0686d4..d5386165d265 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -823,6 +823,23 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
  */
 MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                                   NDArrayHandle *out_handle);
+
+/*!
+ * \brief Create an NDArray backed by a DLPack managed tensor.
+ *
+ * This allows us to create an NDArray using the memory
+ * allocated by an external deep learning framework
+ * that is DLPack compatible.
+ *
+ * The memory is retained until the NDArray goes out of scope.
+ *
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \param out_handle pointer holder to get pointer of NDArray
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
+                                           NDArrayHandle *out_handle);
+
 /*!
  * \brief Delete a dlpack tensor
  * \param dlpack the pointer of the input DLManagedTensor
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 05d3fa45683e..61a75d5b9831 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -587,6 +587,19 @@ class NDArray {
    */
   static NDArray FromDLPack(const DLManagedTensor* tensor);
 
+  /*!
+   * \brief Create an NDArray backed by a DLPack managed tensor.
+   *
+   * This allows us to create an NDArray using the memory
+   * allocated by an external deep learning framework
+   * that is DLPack compatible.
+   *
+   * The memory is retained until the NDArray goes out of scope.
+   *
+   * \return The created NDArray view.
+   */
+  static NDArray FromDLManagedTensor(const DLManagedTensor* tensor);
+
   /*!
    * \brief Update ndarray chunk storage handles using existing ndarray storage handles
    *        Also update the aux_handle, aux_shapes and aux_types.
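The two headers above only declare the new entry point; the Python change below wires it up for real. As a minimal sketch of the calling contract (hypothetical helper, assuming this patch is applied; _LIB, check_call, and NDArrayHandle are the existing helpers in mxnet.base):

import ctypes
from mxnet.base import _LIB, check_call, NDArrayHandle

def ndarray_handle_from(managed_tensor):
    # `managed_tensor` is a DLManagedTensor ctypes struct the caller keeps
    # alive; the resulting NDArray retains the buffer and invokes the
    # tensor's own deleter once the NDArray goes out of scope.
    address = ctypes.cast(ctypes.addressof(managed_tensor), ctypes.c_void_p)
    handle = NDArrayHandle()
    check_call(_LIB.MXNDArrayFromDLManagedTensor(address, ctypes.byref(handle)))
    return handle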
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 8ddbc6511694..7f8838d400a6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -4122,28 +4122,28 @@ class DLContext(ctypes.Structure):
 
 class DLDataType(ctypes.Structure):
-    _fields_ = [("type_code", ctypes.c_uint8),
-                ("bits", ctypes.c_uint8),
-                ("lanes", ctypes.c_uint16)]
-    TYPE_MAP = {
-        "int32": (0, 32, 1),
-        "int64": (0, 64, 1),
-        "bool": (1, 1, 1),
-        "uint32": (1, 32, 1),
-        "uint64": (1, 64, 1),
-        "float32": (2, 32, 1),
-        "float64": (2, 64, 1),
-    }
+    _fields_ = [("type_code", ctypes.c_uint8),
+                ("bits", ctypes.c_uint8),
+                ("lanes", ctypes.c_uint16)]
+    TYPE_MAP = {
+        "int32": (0, 32, 1),
+        "int64": (0, 64, 1),
+        "bool": (1, 1, 1),
+        "uint32": (1, 32, 1),
+        "uint64": (1, 64, 1),
+        "float32": (2, 32, 1),
+        "float64": (2, 64, 1),
+    }
 
 
 class DLTensor(ctypes.Structure):
-    _fields_ = [("data", ctypes.c_void_p),
-                ("ctx", DLContext),
-                ("ndim", ctypes.c_int),
-                ("dtype", DLDataType),
-                ("shape", ctypes.POINTER(ctypes.c_int64)),
-                ("strides", ctypes.POINTER(ctypes.c_int64)),
-                ("byte_offset", ctypes.c_uint64)]
+    _fields_ = [("data", ctypes.c_void_p),
+                ("ctx", DLContext),
+                ("ndim", ctypes.c_int),
+                ("dtype", DLDataType),
+                ("shape", ctypes.POINTER(ctypes.c_int64)),
+                ("strides", ctypes.POINTER(ctypes.c_int64)),
+                ("byte_offset", ctypes.c_uint64)]
 
 
 class DLManagedTensor(ctypes.Structure):
     pass
@@ -4152,16 +4152,16 @@ class DLManagedTensor(ctypes.Structure):
 
 DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))
 
-DLManagedTensor._fields_ = [("dl_tensor", DLTensor),
+DLManagedTensor._fields_ = [("dl_tensor", DLTensor),  # pylint: disable=protected-access
                             ("manager_ctx", ctypes.c_void_p),
                             ("deleter", DeleterFunc)]
 
 
 @DeleterFunc
 def dl_managed_tensor_deleter(dl_managed_tensor_handle):
-    void_p = dl_managed_tensor_handle.contents.manager_ctx
-    pyobj = ctypes.cast(void_p, ctypes.py_object)
-    ctypes.pythonapi.Py_DecRef(pyobj)
+    void_p = dl_managed_tensor_handle.contents.manager_ctx
+    pyobj = ctypes.cast(void_p, ctypes.py_object)
+    ctypes.pythonapi.Py_DecRef(pyobj)
 
 
 def from_numpy(array):
@@ -4173,21 +4173,17 @@ def make_manager_ctx(obj):
         return void_p
 
     def make_dl_tensor(array):
-        # You may check array.flags here, e.g. array.flags['C_CONTIGUOUS']
-        ndim = array.ndim
         dl_tensor = DLTensor()
         dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p)
         dl_tensor.ctx = DLContext(1, 0)
         dl_tensor.ndim = array.ndim
         dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
-        # TODO(@junrushao1994): For 0-dim ndarrays, strides and shape will be NULL
         dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
-        dl_tensor.strides = array.ctypes.strides_as(ctypes.c_int64)
+        dl_tensor.strides = None
         dl_tensor.byte_offset = 0
         return dl_tensor
 
     def make_dl_managed_tensor(array):
-        # TODO(@junrushao1994): improve error message
        c_obj = DLManagedTensor()
        c_obj.dl_tensor = make_dl_tensor(array)
        c_obj.manager_ctx = make_manager_ctx(array)
@@ -4196,7 +4192,8 @@ def make_dl_managed_tensor(array):
 
     assert array.flags['C_CONTIGUOUS'], "We only support c-contiguous numpy arrays"
     c_obj = make_dl_managed_tensor(array)
+    address = ctypes.addressof(c_obj)
+    address = ctypes.cast(address, ctypes.c_void_p)
     handle = NDArrayHandle()
-    check_call(_LIB.MXNDArrayFromDLPack(ctypes.addressof(c_obj), ctypes.byref(handle)))
-    del c_obj
+    check_call(_LIB.MXNDArrayFromDLManagedTensor(address, ctypes.byref(handle)))
     return NDArray(handle=handle)
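With the binding above in place, mx.nd.from_numpy wraps the numpy buffer through DLPack rather than copying it, and make_manager_ctx keeps the numpy array alive via its refcount until the deleter fires. A quick usage sketch (assumes this patch is applied; the input must be C-contiguous, per the assert above):

import numpy as np
import mxnet as mx

a = np.arange(6, dtype="float32").reshape(3, 2)
b = mx.nd.from_numpy(a)   # wraps a's buffer via DLPack; no copy is made

a[0, 0] = 42.0            # mutate the shared buffer through numpy
print(b.asnumpy()[0, 0])  # expected to print 42.0 if the view is zero-copy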
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index f549ddd13994..e85ec1c1ee15 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -569,6 +569,14 @@ int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
   API_END();
 }
 
+int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
+                                 NDArrayHandle *out_handle) {
+  API_BEGIN();
+  *out_handle = new NDArray(NDArray::FromDLManagedTensor(
+      static_cast<DLManagedTensor*>(dlpack)));
+  API_END();
+}
+
 int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
   API_BEGIN();
   if (dlpack != nullptr) {
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index eddfbcff9ce8..d5dfc9897fbf 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -365,6 +365,18 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
   return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
 }
 
+NDArray NDArray::FromDLManagedTensor(const DLManagedTensor* tensor) {
+  const DLTensor &dl_tensor = tensor->dl_tensor;
+  void (*tensor_deleter)(struct DLManagedTensor * self) = tensor->deleter;
+  void *manager_ctx = tensor->manager_ctx;
+  auto deleter = [manager_ctx, tensor_deleter](){
+    if (tensor_deleter != nullptr) {
+      tensor_deleter(static_cast<DLManagedTensor*>(manager_ctx));
+    }
+  };
+  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
+}
+
 bool NDArray::fresh_out_grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return false;
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 94777677354d..57d0e4320767 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1653,6 +1653,14 @@ def test_ndarray_nan_comparison():
     for i in (np.isnan(data1_grad))[1][0].flatten():
         assert i == True
 
+
+def test_zero_from_numpy():
+    print("test_zero_from_numpy")
+    np_array = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
+    mx_array = mx.nd.from_numpy(np_array)
+    print(mx_array)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
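The new test only prints the converted array. A slightly stronger check, sketched under the same assumptions (hypothetical test, not part of this patch), would assert shape and content equality after the round trip:

import numpy as np
import mxnet as mx

def test_from_numpy_roundtrip():
    np_array = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
    mx_array = mx.nd.from_numpy(np_array)
    # The wrapped NDArray should report the same shape and contents.
    assert mx_array.shape == np_array.shape
    assert np.array_equal(mx_array.asnumpy(), np_array)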