Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
should work
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Apr 19, 2019
1 parent c174b2d commit 38383e7
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 30 deletions.
17 changes: 17 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,23 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
NDArrayHandle *out_handle);

/*!
 * \brief Create a NDArray backed by a DLPack managed tensor.
 *
 * This allows us to create a NDArray using the memory
 * allocated by an external deep learning framework
 * that is DLPack compatible.
 *
 * The memory is retained until the NDArray goes out of scope,
 * at which point the producer's deleter is invoked.
 *
 * \param dlpack the pointer of the input DLManagedTensor
 * \param out_handle pointer holder to get pointer of NDArray
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
                                           NDArrayHandle *out_handle);

/*!
* \brief Delete a dlpack tensor
* \param dlpack the pointer of the input DLManagedTensor
Expand Down
13 changes: 13 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,19 @@ class NDArray {
*/
static NDArray FromDLPack(const DLManagedTensor* tensor);

  /*!
   * \brief Create a NDArray backed by a dlpack managed tensor.
   *
   * This allows us to create a NDArray using the memory
   * allocated by an external deep learning framework
   * that is DLPack compatible.
   *
   * The memory is retained until the NDArray goes out of scope;
   * the tensor's own deleter is then invoked to release it.
   *
   * \param tensor the input DLManagedTensor whose memory is borrowed
   * \return The created NDArray view.
   */
  static NDArray FromDLManagedTensor(const DLManagedTensor* tensor);

/*!
* \brief Update ndarray chunk storage handles using existing ndarray storage handles
* Also update the aux_handle, aux_shapes and aux_types.
Expand Down
57 changes: 27 additions & 30 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4122,28 +4122,28 @@ class DLContext(ctypes.Structure):


class DLDataType(ctypes.Structure):
    """ctypes mirror of DLPack's ``DLDataType`` struct.

    Layout must match the C struct exactly: (type_code, bits, lanes).
    """
    # NOTE: the diff residue duplicated these class attributes; they are
    # defined exactly once here.
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    # numpy dtype name -> (type_code, bits, lanes).
    # DLPack type codes: 0 = signed int, 1 = unsigned int, 2 = float.
    TYPE_MAP = {
        "int32": (0, 32, 1),
        "int64": (0, 64, 1),
        "bool": (1, 1, 1),
        "uint32": (1, 32, 1),
        "uint64": (1, 64, 1),
        "float32": (2, 32, 1),
        "float64": (2, 64, 1),
    }


class DLTensor(ctypes.Structure):
    """ctypes mirror of DLPack's ``DLTensor`` struct.

    Field order and types must match the C struct exactly so that the
    structure can be handed to the C API by address.
    """
    # NOTE: the diff residue duplicated the _fields_ definition; it is
    # defined exactly once here.
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", DLContext),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]

# Forward declaration: _fields_ is assigned later, after DeleterFunc exists,
# because the deleter's CFUNCTYPE signature refers back to this type.
class DLManagedTensor(ctypes.Structure):
    pass
Expand All @@ -4152,16 +4152,16 @@ class DLManagedTensor(ctypes.Structure):
# C callback type for DLPack's deleter: void (*)(DLManagedTensor*).
DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))


# Complete the forward-declared DLManagedTensor now that DeleterFunc exists.
# (The diff residue interleaved two copies of this assignment, which is a
# syntax error; the annotated version is kept exactly once.)
DLManagedTensor._fields_ = [("dl_tensor", DLTensor),  # pylint: disable=protected-access
                            ("manager_ctx", ctypes.c_void_p),
                            ("deleter", DeleterFunc)]


@DeleterFunc
def dl_managed_tensor_deleter(dl_managed_tensor_handle):
    """DLPack deleter: drop the reference taken on the source Python object.

    ``manager_ctx`` holds a borrowed pointer to the Python object whose
    refcount was incremented when the tensor was exported; decrement it
    exactly once. (The diff residue duplicated this body, which would have
    called Py_DecRef twice per invocation — a double release.)
    """
    void_p = dl_managed_tensor_handle.contents.manager_ctx
    pyobj = ctypes.cast(void_p, ctypes.py_object)
    ctypes.pythonapi.Py_DecRef(pyobj)


def from_numpy(array):
Expand All @@ -4173,21 +4173,17 @@ def make_manager_ctx(obj):
return void_p

def make_dl_tensor(array):
# You may check array.flags here, e.g. array.flags['C_CONTIGUOUS']
ndim = array.ndim
dl_tensor = DLTensor()
dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p)
dl_tensor.ctx = DLContext(1, 0)
dl_tensor.ndim = array.ndim
dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
# TODO(@junrushao1994): For 0-dim ndarrays, strides and shape will be NULL
dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
dl_tensor.strides = array.ctypes.strides_as(ctypes.c_int64)
dl_tensor.strides = None
dl_tensor.byte_offset = 0
return dl_tensor

def make_dl_managed_tensor(array):
# TODO(@junrushao1994): improve error message
c_obj = DLManagedTensor()
c_obj.dl_tensor = make_dl_tensor(array)
c_obj.manager_ctx = make_manager_ctx(array)
Expand All @@ -4196,7 +4192,8 @@ def make_dl_managed_tensor(array):

assert array.flags['C_CONTIGUOUS'], "We only support c-contiguous numpy arrays"
c_obj = make_dl_managed_tensor(array)
address = ctypes.addressof(c_obj)
address = ctypes.cast(address, ctypes.c_void_p)
handle = NDArrayHandle()
check_call(_LIB.MXNDArrayFromDLPack(ctypes.addressof(c_obj), ctypes.byref(handle)))
del c_obj
check_call(_LIB.MXNDArrayFromDLManagedTensor(address, ctypes.byref(handle)))
return NDArray(handle=handle)
8 changes: 8 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,14 @@ int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
API_END();
}

// C API shim: wraps NDArray::FromDLManagedTensor in the standard
// API_BEGIN/API_END exception-to-error-code translation. The returned
// NDArray* is heap-allocated; ownership passes to the caller, who must
// release it via the matching NDArray free API.
int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
                                 NDArrayHandle *out_handle) {
  API_BEGIN();
  // DLManagedTensorHandle is an opaque void*-style alias; recover the
  // concrete DLPack type before constructing the view.
  *out_handle = new NDArray(NDArray::FromDLManagedTensor(
      static_cast<DLManagedTensor*>(dlpack)));
  API_END();
}

int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
API_BEGIN();
if (dlpack != nullptr) {
Expand Down
12 changes: 12 additions & 0 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,18 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
}

/*
 * Create an NDArray view over memory owned by a DLPack producer.
 * The producer's deleter is invoked when the NDArray's chunk is destroyed.
 */
NDArray NDArray::FromDLManagedTensor(const DLManagedTensor* tensor) {
  const DLTensor &dl_tensor = tensor->dl_tensor;
  void (*tensor_deleter)(struct DLManagedTensor * self) = tensor->deleter;
  // BUG FIX: the DLPack contract is that `deleter` receives the
  // DLManagedTensor itself ("self"), not `manager_ctx`. manager_ctx is an
  // opaque pointer owned by the producer and must not be reinterpreted
  // here. Capture `tensor` and hand it back to the deleter unchanged.
  auto deleter = [tensor, tensor_deleter](){
    if (tensor_deleter != nullptr) {
      tensor_deleter(const_cast<DLManagedTensor*>(tensor));
    }
  };
  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
}

bool NDArray::fresh_out_grad() const {
if (Imperative::AGInfo::IsNone(*this)) return false;
Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,14 @@ def test_ndarray_nan_comparison():
for i in (np.isnan(data1_grad))[1][0].flatten():
assert i == True


def test_zero_from_numpy():
    """from_numpy must produce an NDArray with identical values.

    The original test only printed the result and asserted nothing, so it
    could never fail on wrong data; compare round-tripped values instead.
    """
    np_array = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
    mx_array = mx.nd.from_numpy(np_array)
    assert mx_array.shape == np_array.shape
    assert np.array_equal(np_array, mx_array.asnumpy())


if __name__ == '__main__':
    # Discover and run every test_* function in this module under nose
    # when the file is executed directly.
    import nose
    nose.runmodule()

0 comments on commit 38383e7

Please sign in to comment.