
Commit 5e1a175

should work

junrushao committed Apr 18, 2019
1 parent c174b2d
Showing 6 changed files with 66 additions and 7 deletions.

include/mxnet/c_api.h: 17 additions & 0 deletions

@@ -823,6 +823,23 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
  */
 MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
                                   NDArrayHandle *out_handle);
+
+/*!
+ * \brief Create an NDArray backed by a DLPack managed tensor.
+ *
+ * This allows us to create an NDArray using memory
+ * allocated by an external deep learning framework
+ * that is DLPack compatible.
+ *
+ * The memory is retained until the NDArray goes out of scope.
+ *
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \param out_handle pointer holder to get pointer of NDArray
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
+                                           NDArrayHandle *out_handle);
+
 /*!
  * \brief Delete a dlpack tensor
  * \param dlpack the pointer of the input DLManagedTensor
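
For orientation, the new entry point can be driven straight from Python through ctypes, which is exactly what from_numpy (changed below) does. A minimal sketch, assuming _LIB, check_call, and NDArrayHandle from mxnet.base; the wrapper name ndarray_from_dl_managed_tensor is hypothetical, and dlpack_ptr must point at a live DLManagedTensor built elsewhere:

    import ctypes
    from mxnet.base import _LIB, check_call, NDArrayHandle

    # Declare the C signature: a raw DLManagedTensor* (passed as void*)
    # and an out-parameter that receives the new NDArray handle.
    _LIB.MXNDArrayFromDLManagedTensor.argtypes = [
        ctypes.c_void_p, ctypes.POINTER(NDArrayHandle)]
    _LIB.MXNDArrayFromDLManagedTensor.restype = ctypes.c_int

    def ndarray_from_dl_managed_tensor(dlpack_ptr):
        """Hypothetical wrapper around the new C entry point."""
        handle = NDArrayHandle()
        # check_call raises MXNetError whenever the C API returns -1.
        check_call(_LIB.MXNDArrayFromDLManagedTensor(
            dlpack_ptr, ctypes.byref(handle)))
        return handle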

include/mxnet/ndarray.h: 13 additions & 0 deletions

@@ -587,6 +587,19 @@ class NDArray {
   */
   static NDArray FromDLPack(const DLManagedTensor* tensor);
 
+  /*!
+   * \brief Create an NDArray backed by a DLPack managed tensor.
+   *
+   * This allows us to create an NDArray using memory
+   * allocated by an external deep learning framework
+   * that is DLPack compatible.
+   *
+   * The memory is retained until the NDArray goes out of scope.
+   *
+   * \return The created NDArray view.
+   */
+  static NDArray FromDLManagedTensor(const DLManagedTensor* tensor);
+
   /*!
    * \brief Update ndarray chunk storage handles using existing ndarray storage handles
    * Also update the aux_handle, aux_shapes and aux_types.

python/mxnet/ndarray/ndarray.py: 8 additions & 7 deletions

@@ -4159,9 +4159,9 @@ class DLManagedTensor(ctypes.Structure):

 @DeleterFunc
 def dl_managed_tensor_deleter(dl_managed_tensor_handle):
-        void_p = dl_managed_tensor_handle.contents.manager_ctx
-        pyobj = ctypes.cast(void_p, ctypes.py_object)
-        ctypes.pythonapi.Py_DecRef(pyobj)
+    void_p = dl_managed_tensor_handle.contents.manager_ctx
+    pyobj = ctypes.cast(void_p, ctypes.py_object)
+    ctypes.pythonapi.Py_DecRef(pyobj)
 
 
 def from_numpy(array):
@@ -4182,21 +4182,22 @@ def make_dl_tensor(array):
         dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
         # TODO(@junrushao1994): For 0-dim ndarrays, strides and shape will be NULL
         dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
-        dl_tensor.strides = array.ctypes.strides_as(ctypes.c_int64)
+        dl_tensor.strides = None
         dl_tensor.byte_offset = 0
         return dl_tensor
 
     def make_dl_managed_tensor(array):
-        # TODO(@junrushao1994): improve error message
         c_obj = DLManagedTensor()
         c_obj.dl_tensor = make_dl_tensor(array)
         c_obj.manager_ctx = make_manager_ctx(array)
         c_obj.deleter = dl_managed_tensor_deleter
         return c_obj
 
+    # TODO(@junrushao1994): improve error message
     assert array.flags['C_CONTIGUOUS'], "We only support c-contiguous numpy arrays"
     c_obj = make_dl_managed_tensor(array)
+    address = ctypes.addressof(c_obj)
+    address = ctypes.cast(address, ctypes.c_void_p)
     handle = NDArrayHandle()
-    check_call(_LIB.MXNDArrayFromDLPack(ctypes.addressof(c_obj), ctypes.byref(handle)))
-    del c_obj
+    check_call(_LIB.MXNDArrayFromDLManagedTensor(address, ctypes.byref(handle)))
     return NDArray(handle=handle)
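
To make the zero-copy intent concrete: after this change the NDArray aliases the numpy buffer instead of copying it, which is also why the C-contiguity assert above exists. A minimal sketch of the expected behavior, illustrating the design intent rather than guaranteed output of this work-in-progress commit:

    import numpy as np
    import mxnet as mx

    a = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
    b = mx.nd.from_numpy(a)   # zero-copy: b is backed by a's buffer

    a[0, 0] = 100.0           # mutate the shared memory through numpy
    print(b.asnumpy()[0, 0])  # expected: 100.0, since no copy was made

    # Non C-contiguous inputs trip the assert in from_numpy:
    # mx.nd.from_numpy(a.T)  ->  AssertionError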

src/c_api/c_api.cc: 8 additions & 0 deletions

@@ -569,6 +569,14 @@ int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
   API_END();
 }
 
+int MXNDArrayFromDLManagedTensor(DLManagedTensorHandle dlpack,
+                                 NDArrayHandle *out_handle) {
+  API_BEGIN();
+  *out_handle = new NDArray(NDArray::FromDLManagedTensor(
+      static_cast<DLManagedTensor*>(dlpack)));
+  API_END();
+}
+
 int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
   API_BEGIN();
   if (dlpack != nullptr) {

src/ndarray/ndarray.cc: 12 additions & 0 deletions

@@ -365,6 +365,18 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
   return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
 }
 
+NDArray NDArray::FromDLManagedTensor(const DLManagedTensor* tensor) {
+  const DLTensor &dl_tensor = tensor->dl_tensor;
+  void (*tensor_deleter)(struct DLManagedTensor* self) = tensor->deleter;
+  auto deleter = [tensor, tensor_deleter]() {
+    // Per DLPack, the deleter is invoked on the DLManagedTensor itself.
+    if (tensor_deleter != nullptr) {
+      tensor_deleter(const_cast<DLManagedTensor*>(tensor));
+    }
+  };
+  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
+}
+
 bool NDArray::fresh_out_grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return false;
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
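
The deleter chain above is what ties the NDArray's lifetime to the numpy buffer: make_manager_ctx (not shown in this diff) presumably takes a reference to the array, and dl_managed_tensor_deleter releases it with Py_DecRef once the NDArray's deleter fires. A hedged sketch of observing this from Python; MXNet frees NDArrays asynchronously, so the final refcount is expected, not guaranteed, to drop right away:

    import sys
    import numpy as np
    import mxnet as mx

    a = np.ones((4,), dtype="float32")
    before = sys.getrefcount(a)

    b = mx.nd.from_numpy(a)    # the DLManagedTensor takes a reference to `a`
    print(sys.getrefcount(a))  # expected: before + 1; the NDArray co-owns `a`

    del b                      # destroying the NDArray runs the C++ deleter,
                               # which ends in dl_managed_tensor_deleter
    mx.nd.waitall()            # let the async engine finish pending frees
    print(sys.getrefcount(a))  # expected: back to `before`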

tests/python/unittest/test_ndarray.py: 8 additions & 0 deletions

@@ -1653,6 +1653,14 @@ def test_ndarray_nan_comparison():
     for i in (np.isnan(data1_grad))[1][0].flatten():
         assert i == True
 
+
+def test_zero_from_numpy():
+    print("test_zero_from_numpy")
+    np_array = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
+    mx_array = mx.nd.from_numpy(np_array)
+    print(mx_array)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
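
test_zero_from_numpy above is a smoke test that only prints. A natural strengthening, sketched here under a hypothetical name, is to assert that the values survive the round trip:

    def test_from_numpy_values():
        # Hypothetical follow-up to test_zero_from_numpy: check values
        # instead of printing them.
        np_array = np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
        mx_array = mx.nd.from_numpy(np_array)
        assert np.array_equal(np_array, mx_array.asnumpy())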
