This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Adding sparse support to MXTensor for custom operators #17569

Merged
28 commits, merged on Mar 22, 2020
Changes from 2 commits

Commits (28)
a63bae9
Added enum for sparse storage
guanxinq Feb 11, 2020
93cddf4
Add structure for Dense and Sparse
guanxinq Feb 11, 2020
8ccfbd2
redesign the data structure for MXSparse
guanxinq Feb 13, 2020
8c9b358
pull out aux data from sparse NDArray
guanxinq Feb 14, 2020
2bf9200
Added more sparse arguments to API interface
guanxinq Feb 15, 2020
7eba53c
Passed sparse from c_api to lib_api.h and set in MXTensor
guanxinq Feb 17, 2020
3fdf771
Fix indent
guanxinq Feb 17, 2020
a1aa78f
fix segfault
guanxinq Feb 19, 2020
0537deb
Fix NDArray to MXTensor errors
guanxinq Feb 25, 2020
4f44695
Add a sample of sparse(CSR) transpose
guanxinq Feb 25, 2020
ade3e46
Make CSR transpose temporarily work by hardcoding
guanxinq Feb 26, 2020
9a26ac3
Fixed sparse output size(Refined)
guanxinq Mar 2, 2020
041470b
Add tests for symbolic and stateful ops
guanxinq Mar 3, 2020
a3b175b
Added a sample for row sparse transpose
guanxinq Mar 3, 2020
99d00c2
Added real row sparse transpose
guanxinq Mar 3, 2020
60e6753
Fix output size issue by adding lambda for CheckAndAlloc()
guanxinq Mar 10, 2020
3e7f23c
Fix mixed storage formats error
guanxinq Mar 11, 2020
b97bfad
Added infer storage type function
guanxinq Mar 12, 2020
41f0784
resolve comments
guanxinq Mar 13, 2020
bd40098
Set inferSType as optional function
guanxinq Mar 16, 2020
7e95dca
Resolve comments
guanxinq Mar 17, 2020
3f963f5
Add error messages
guanxinq Mar 17, 2020
0eb1de9
Resolve comments
guanxinq Mar 18, 2020
79d7d64
verify transpose ops results
guanxinq Mar 18, 2020
89d638f
Resolved merge conflict
guanxinq Mar 19, 2020
9dcb604
fix sanity check
guanxinq Mar 19, 2020
08faed4
Merge and resolve conflicts
guanxinq Mar 19, 2020
7f39b85
update MX_LIBRARY_VERSION to 5
guanxinq Mar 20, 2020
69 changes: 61 additions & 8 deletions include/mxnet/lib_api.h
@@ -214,6 +214,18 @@ enum MXDType {
kUNSET = 100,
};

/*!
 * \brief MXTensor storage type.
 */
enum MXStorageType {
// dense
kDefaultStorage = 0,
// row sparse
kRowSparseStorage = 1,
// csr
kCSRStorage = 2,
};
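As a quick illustration (not part of the diff), here is how a custom operator might branch on the new storage type. The forward signature below follows the lib_api.h custom-op examples of this era, but treat the exact parameter list and the name myForward as assumptions:

#include "mxnet/lib_api.h"

// Hypothetical forward function: dispatch on the input's storage type.
MXReturnValue myForward(std::map<std::string, std::string> attrs,
                        std::vector<MXTensor> inputs,
                        std::vector<MXTensor> outputs,
                        OpResource res) {
  switch (inputs[0].getStorageType()) {
    case kDefaultStorage:
      // dense: data_ptr points to a ChunkDense
      break;
    case kRowSparseStorage:
      // row sparse: aux_data[0] holds the indices of the stored rows
      break;
    case kCSRStorage:
      // CSR: aux_data[0] is indptr, aux_data[1] is indices
      break;
  }
  return MX_SUCCESS;
}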

/*!
* \brief Context info passing from MXNet OpContext
* dev_type is string repr of supported context, currently only "cpu" and "gpu"
@@ -229,20 +241,54 @@ enum MXReturnValue {
MX_SUCCESS = 1,
};

struct ChunkDense {
// Pointer to data.
void *data{nullptr};
// Size of data in bytes.
size_t dataSize{0};
// shape of data.
std::vector<int64_t> shape;
// Context of data.
// MXContext ctx;
};
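For illustration (again, not part of the diff), a dense 2x3 float tensor would populate these fields roughly as follows, assuming this ChunkDense layout:

float buf[6] = {1, 2, 3, 4, 5, 6};  // row-major values
ChunkDense cd;
cd.data = buf;                      // raw value buffer
cd.dataSize = sizeof(buf);          // 6 * sizeof(float) = 24 bytes
cd.shape = {2, 3};                  // two rows, three columns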

struct ChunkSparse {
// Pointer to data.
void *data{nullptr};
// Size of data in bytes.
size_t dataSize{0};
// Number of elements in data.
int64_t data_lens{0};

// To store aux data for sparse.
// for row_sparse, aux_data[0] = indices
// for csr, aux_data[0] = indptr, aux_data[1] = indices
std::vector<std::vector<int64_t>> aux_data;

// Lengths of the aux_data arrays.
// for row_sparse, aux_lens[0] = len(indices)
// for csr, aux_lens[0] = len(indptr), aux_lens[1] = len(indices)
std::vector<int64_t> aux_lens;
// Context of data.
// MXContext ctx;
};
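A minimal sketch of how a small CSR matrix maps onto these fields; the layout (values in data, indptr/indices in aux_data) follows the comments above:

// The 2x3 matrix [[0, 7, 0],
//                 [5, 0, 8]] in CSR form.
float values[] = {7.0f, 5.0f, 8.0f};  // stored non-zero values
ChunkSparse cs;
cs.data = values;
cs.dataSize = sizeof(values);          // 3 * sizeof(float) bytes
cs.data_lens = 3;                      // number of stored values
cs.aux_data = {{0, 1, 3},              // indptr: row i spans [indptr[i], indptr[i+1])
               {1, 0, 2}};             // indices: column index of each stored value
cs.aux_lens = {3, 3};                  // len(indptr), len(indices)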

/*!
* \brief Tensor data structure used by custom operator
*/
struct MXTensor {
-  MXTensor() : data_ptr(NULL), dtype(kUNSET), verID(0) {}
+  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}

// Constructor for dense tensors.
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
-           size_t vID, MXContext mx_ctx)
-    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {}
+           size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
+    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {}

/*! \brief populate internal tensor fields */
  // TODO: handle the CSR and row-sparse cases.
  void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
-                size_t vID, MXContext mx_ctx) {
-    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx;
+                size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage) {
+    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = stype;
shape.clear();
for (int j = 0; j < ndims; j++) {
shape.push_back(dims[j]);
@@ -335,11 +381,15 @@ struct MXTensor {
verID == oth.verID &&
ctx.dev_type == oth.ctx.dev_type &&
ctx.dev_id == oth.ctx.dev_id &&
-           shape == oth.shape;
+           shape == oth.shape &&
+           stype == oth.stype;
}

-  // data is flatten 1D repr of tensor, elements are in continuous memory
-  // user can access each element using the shape of tensor
+  /*! \brief get the MXTensor's storage type */
+  inline MXStorageType getStorageType() { return stype; }
+
+  // For dense, data_ptr points to a ChunkDense.
+  // For sparse, data_ptr points to a ChunkSparse.
void *data_ptr;

// shape is in [2,3,4] format to represent high-dim tensor
@@ -357,6 +407,9 @@ struct MXTensor {
// corresponding DLTensor repr of MXTensor
// easy way to reuse functions taking DLTensor
DLTensor dltensor;

// storage type
MXStorageType stype;
};
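To show the backward-compatible default, a sketch of constructing tensors with and without the new argument. It assumes MXContext carries the dev_type/dev_id fields described in the header comment above, and that kFloat32 is a member of MXDType; both names are taken from lib_api.h, but treat the construction style as illustrative:

MXContext ctx;
ctx.dev_type = "cpu";
ctx.dev_id = 0;

std::vector<int64_t> shape = {2, 3};
float buf[6] = {0};

// Existing dense callers are unchanged: stype defaults to kDefaultStorage.
MXTensor dense(buf, shape, kFloat32, /*vID=*/1, ctx);

// Sparse callers pass the storage type explicitly; data_ptr then points
// to a ChunkSparse rather than a raw value buffer.
ChunkSparse cs;  // filled as in the CSR sketch above
MXTensor csr(&cs, shape, kFloat32, /*vID=*/1, ctx, kCSRStorage);
// csr.getStorageType() now returns kCSRStorage.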

/*! \brief resource malloc function to allocate memory inside Forward/Backward functions */