diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile
index edd753b0759c..feded2947ca3 100644
--- a/example/extensions/lib_custom_op/Makefile
+++ b/example/extensions/lib_custom_op/Makefile
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-all: gemm_lib relu_lib
+all: gemm_lib relu_lib transposecsr_lib transposerowsp_lib
 
 gemm_lib:
 	g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet
@@ -23,5 +23,11 @@ gemm_lib:
 relu_lib:
 	nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet
 
+transposecsr_lib:
+	g++ -shared -fPIC -std=c++11 transposecsr_lib.cc -o libtransposecsr_lib.so -I ../../../include/mxnet
+
+transposerowsp_lib:
+	g++ -shared -fPIC -std=c++11 transposerowsp_lib.cc -o libtransposerowsp_lib.so -I ../../../include/mxnet
+
 clean:
-	rm -rf libgemm_lib.so librelu_lib.so
+	rm -rf libgemm_lib.so librelu_lib.so libtransposecsr_lib.so libtransposerowsp_lib.so
diff --git a/example/extensions/lib_custom_op/test_transposecsr.py b/example/extensions/lib_custom_op/test_transposecsr.py
new file mode 100644
index 000000000000..37d066a7bec2
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposecsr.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks dynamic loading of a custom library into MXNet
+# and end-to-end compute of a 2D transpose of a CSR array
+
+import mxnet as mx
+import os
+
+# load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposecsr_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposecsr_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,3,0,2,1],[0,1,0,0,0],[0,2,4,5,3]])
+a = a.tostype('csr')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+print("indptr:", a.indptr.asnumpy())
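+# For reference (comment added for illustration; the values follow directly
+# from the CSR definition), the prints above should show:
+#   data:    [1. 3. 2. 1. 1. 2. 4. 5. 3.]  non-zero values, row-major
+#   indices: [0 1 3 4 1 1 2 3 4]           column index of each value
+#   indptr:  [0 4 5 9]                     row i lives in data[indptr[i]:indptr[i+1]]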
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposecsr(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+print("indptr:", b.indptr.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposecsr(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+print("indptr:", c.indptr.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposecsr(d)
+f = mx.sym.my_state_transposecsr(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+print("indptr:", out[0].indptr.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+print("indptr:", out2[0].indptr.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/test_transposerowsp.py b/example/extensions/lib_custom_op/test_transposerowsp.py
new file mode 100644
index 000000000000..cea62ec6e98c
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposerowsp.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks dynamic loading of a custom library into MXNet
+# and end-to-end compute of a 2D transpose of a row-sparse array
+
+import mxnet as mx
+import os
+
+# load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposerowsp_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposerowsp_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,2,3],[0,0,0],[4,0,5],[0,0,0],[0,0,0]])
+a = a.tostype('row_sparse')
+print("--------Input Row-Sparse Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
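+# For reference (comment added for illustration): row_sparse storage keeps
+# only the rows that contain non-zero entries, so the prints above should show:
+#   data:    [[1. 2. 3.]
+#             [4. 0. 5.]]   the retained rows, kept dense
+#   indices: [0 2]          which rows of the full tensor they are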
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposerowsp(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposerowsp(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposerowsp(d)
+f = mx.sym.my_state_transposerowsp(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc
new file mode 100644
index 000000000000..0daeb3e9f83e
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposecsr_lib.cc
+ * \brief Sample 2D transpose custom operator for CSR tensors.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>();
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if(src.stype == kCSRStorage) {
+    float *Aval = (float*) (A->data);
+    // One extra element helps the index calculation in the scatter loop below.
+    std::vector<int64_t> rowPtr(w + 2, 0);
+    // count the non-zeros in each column
+    for(int i = 0; i < A->data_len; i++) {
+      rowPtr[A->indices[i] + 2]++;
+    }
+    // Accumulated sum. After this for loop, rowPtr[1:w+2) stores the correct
+    // result of the transposed rowPtr.
+    for(int i = 2; i < rowPtr.size(); i++) {
+      rowPtr[i] += rowPtr[i - 1];
+    }
+
+    // Alloc memory for sparse data, where 0 is the index
+    // of B in the output vector.
+    res.alloc_sparse(B, 0, A->data_len, w + 1);
+    float *Bval = (float*) (B->data);
+    for(int i = 0; i < h; i++) {
+      for(int j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
+        // Helps calculate the index; after this loop rowPtr[0:w+1) stores the
+        // correct result of the transposed rowPtr.
+        int index = rowPtr[A->indices[j] + 1]++;
+        Bval[index] = Aval[j];
+        B->indices[index] = i;
+      }
+    }
+    memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
+  }
+}
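+
+// Worked example (comment added for illustration, not part of the original
+// source): transposing A = [[0,1,0],[2,0,3]] (CSR data=[1,2,3],
+// indices=[1,0,2], indptr=[0,1,3]) proceeds as
+//   counting:     rowPtr = [0,0,1,1,1]
+//   accumulating: rowPtr = [0,0,1,2,3]   (rowPtr[1:5) is B's indptr)
+//   scattering:   Bval = [2,1,3], B->indices = [1,0,1], B->indptr = [0,1,2,3]
+// which is exactly the CSR form of A^T = [[0,2],[1,0],[0,3]].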
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    std::cout << "Error! Expected all inputs and outputs to be the same type. "
+              << "Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+  *num_in = 1;
+  *num_out = 1;
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  // validate inputs
+  if (intypes.size() != 1) {
+    std::cout << "Expected 1 input to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  if (intypes[0] != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+
+  outtypes[0] = intypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                         std::vector<int> &instypes,
+                         std::vector<int> &outstypes) {
+  if (instypes[0] != kCSRStorage) {
+    std::cout << "Expected input storage type to be kCSRStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 input to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  outshapes[0].push_back(inshapes[0][1]);
+  outshapes[0].push_back(inshapes[0][0]);
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposecsr)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
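+
+// Note (comment added for clarity): once the library is loaded with
+// mx.library.load(), the registration above exposes this operator to the
+// frontends as mx.nd.my_transposecsr and mx.sym.my_transposecsr; both paths
+// are exercised by test_transposecsr.py.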
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeCSR : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeCSR(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+    std::map<std::string, std::string> attrs;
+    return forward(attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    std::map<std::string, std::string> attrs;
+    return backward(attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  // testing passing of keyword arguments
+  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+  // creating stateful operator instance
+  *op_inst = new MyStatefulTransposeCSR(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposecsr)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+  if (version >= 10400) {
+    std::cout << "MXNet version " << version << " supported" << std::endl;
+    return MX_SUCCESS;
+  } else {
+    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    return MX_FAIL;
+  }
+}
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
new file mode 100644
index 000000000000..883d816cfa81
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposerowsp_lib.cc
+ * \brief Sample 2D transpose custom operator for row-sparse tensors.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>();
+
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if(src.stype == kRowSparseStorage) {
+    // Keys of the map are the row indices of the transposed tensor.
+    // Values of the map are the rows holding its non-zero elements.
+    std::map<int, std::vector<float>> mp;
+    float *Aval = (float*) (A->data);
+    for(int i = 0; i < A->data_len; i++) {
+      int row = i / w;
+      int col = i % w;
+      row = A->indices[row];
+      if(Aval[i] != 0) {
+        if(mp.find(col) == mp.end()) {
+          mp[col] = std::vector<float>(h, 0);
+          mp[col][row] = Aval[i];
+        }
+        else {
+          mp[col][row] = Aval[i];
+        }
+      }
+    }
+
+    // Alloc memory for output tensors.
+    res.alloc_sparse(B, 0, mp.size());
+    float *Bval = (float*) (B->data);
+    int didx = 0, iidx = 0;
+    for(auto i : mp) {
+      B->indices[iidx++] = i.first;
+      for(auto j : i.second) {
+        Bval[didx++] = j;
+      }
+    }
+  }
+}
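+
+// Worked example (comment added for illustration, not part of the original
+// source): for A = [[1,2,3],[0,0,0],[4,0,5]] stored as data=[[1,2,3],[4,0,5]],
+// indices=[0,2], the map ends up as
+//   mp[0] = [1,0,4], mp[1] = [2,0,0], mp[2] = [3,0,5]
+// so B (= A^T) gets indices=[0,1,2] and data=[[1,0,4],[2,0,0],[3,0,5]].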
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    std::cout << "Error! Expected all inputs and outputs to be the same type. "
+              << "Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+  *num_in = 1;
+  *num_out = 1;
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  // validate inputs
+  if (intypes.size() != 1) {
+    std::cout << "Expected 1 input to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  if (intypes[0] != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+
+  outtypes[0] = intypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                         std::vector<int> &instypes,
+                         std::vector<int> &outstypes) {
+  if (instypes[0] != kRowSparseStorage) {
+    std::cout << "Expected input storage type to be kRowSparseStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 input to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  outshapes[0].push_back(inshapes[0][1]);
+  outshapes[0].push_back(inshapes[0][0]);
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposerowsp)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeRowSP : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeRowSP(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+    std::map<std::string, std::string> attrs;
+    return forward(attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    std::map<std::string, std::string> attrs;
+    return backward(attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  // testing passing of keyword arguments
+  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+  // creating stateful operator instance
+  *op_inst = new MyStatefulTransposeRowSP(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposerowsp)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+  if (version >= 10400) {
+    std::cout << "MXNet version " << version << " supported" << std::endl;
+    return MX_SUCCESS;
+  } else {
+    std::cout << "MXNet version " << version << " not supported" << std::endl;
+    return MX_FAIL;
+  }
+}
diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc
index 8c24dd880f72..d821bdb0d1c2 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute exp operator
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 9b32122c7d7a..fd526ee4172f 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -39,7 +39,7 @@
 #include <utility>
 #include <stdexcept>
 
-#define MX_LIBRARY_VERSION 4
+#define MX_LIBRARY_VERSION 5
 
 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -214,6 +214,18 @@ enum MXDType {
   kUNSET = 100,
 };
 
+/*
+ * MXTensor storage type.
+ */
+enum MXStorageType {
+  // dense
+  kDefaultStorage = 0,
+  // row sparse
+  kRowSparseStorage = 1,
+  // csr
+  kCSRStorage = 2,
+};
+
 /*!
  * \brief Context info passing from MXNet OpContext
  * dev_type is string repr of supported context, currently only "cpu" and "gpu"
@@ -229,25 +241,64 @@ enum MXReturnValue {
   MX_SUCCESS = 1,
 };
 
+// For sparse tensors, read/write the data from NDArray via pointers.
+struct MXSparse {
+  // Pointer to data.
+  void *data{nullptr};
+  // length of (non-zero) data.
+  int64_t data_len;
+
+  // To store aux data for sparse.
+  // For CSR, indices stores the column index of each non-zero element.
+  // For row sparse, indices stores the row index of each row that has non-zero elements.
+  int64_t* indices;
+  int64_t indices_len;
+
+  // For CSR, indptr gives the start and end index of data for each row.
+  // For row sparse, indptr is not used.
+  int64_t* indptr = nullptr;
+  int64_t indptr_len;
+
+  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
+           int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
+    data = data_ptr;
+    // If CSR, the number of non-zero elements is num_idx;
+    // if row sparse, the number of elements is num_idx * width.
+    data_len = num_idx;
+    if (!idx_ptr) {
+      for (int i = 1; i < ndims; ++i)
+        data_len *= dims[i];
+    }
+
+    indices = reinterpret_cast<int64_t*>(idx);
+    indices_len = num_idx;
+
+    if (idx_ptr) {
+      indptr = reinterpret_cast<int64_t*>(idx_ptr);
+      indptr_len = num_idx_ptr;
+    }
+  }
+};
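+
+// Usage sketch (comment added for illustration; see transposecsr_lib.cc for a
+// complete operator). For a sparse tensor, MXTensor::data_ptr points at an
+// MXSparse record rather than at raw values:
+//   MXSparse* A = tensor.data<MXSparse>();
+//   float* val = static_cast<float*>(A->data);
+//   for (int64_t r = 0; r < A->indptr_len - 1; r++)        // CSR: rows = indptr_len - 1
+//     for (int64_t j = A->indptr[r]; j < A->indptr[r + 1]; j++)
+//       /* val[j] sits at (r, A->indices[j]) */;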
+
 /*!
  * \brief Tensor data structure used by custom operator
  */
 struct MXTensor {
-  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0) {}
+  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
 
   MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
-    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx) {
+    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
     setDLTensor();
   }
 
   MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
-           size_t vID, MXContext mx_ctx)
-    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {
+           size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
+    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {
     setDLTensor();
   }
 
   /*! \brief populate internal tensor fields */
   void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
-                 size_t vID, MXContext mx_ctx) {
-    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx;
+                 size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
+    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type;
     shape.clear();
     for (int j = 0; j < ndims; j++) {
       shape.push_back(dims[j]);
@@ -340,11 +391,12 @@ struct MXTensor {
            verID == oth.verID &&
            ctx.dev_type == oth.ctx.dev_type &&
            ctx.dev_id == oth.ctx.dev_id &&
-           shape == oth.shape;
+           shape == oth.shape &&
+           stype == oth.stype;
   }
 
-  // data is flatten 1D repr of tensor, elements are in continuous memory
-  // user can access each element using the shape of tensor
+  // For dense, data_ptr points to data.
+  // For sparse, data_ptr points to MXSparse.
   void *data_ptr;
 
   // shape is in [2,3,4] format to represent high-dim tensor
@@ -362,11 +414,16 @@ struct MXTensor {
   // corresponding DLTensor repr of MXTensor
   // easy way to reuse functions taking DLTensor
   DLTensor dltensor;
+
+  // storage type
+  MXStorageType stype;
 };
 
 /*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
 typedef void* (*xpu_malloc_t)(void*, int);
 
+typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
+
 #if defined(__NVCC__)
   typedef cudaStream_t mx_stream_t;
 #else
@@ -379,9 +436,11 @@ typedef void* (*xpu_malloc_t)(void*, int);
 class OpResource {
  public:
   OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
-             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream)
+             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
+             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
     : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
-      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream) {}
+      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
+      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
 
   /*! \brief allocate cpu memory controlled by MXNet */
   void* alloc_cpu(int size) {
@@ -398,6 +457,12 @@ class OpResource {
     return static_cast<mx_stream_t>(cuda_stream);
   }
 
+  /*! \brief allocate sparse memory controlled by MXNet */
+  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) {
+    sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
+                  &(sparse->data), &(sparse->indices), &(sparse->indptr));
+  }
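+
+  // Usage sketch (comment added for illustration; names like nnz are
+  // placeholders): a Forward() implementation allocates the output's sparse
+  // storage before writing to it, e.g. for output 0:
+  //   MXSparse* B = outputs[0].data<MXSparse>();
+  //   res.alloc_sparse(B, 0, nnz, num_rows + 1);  // CSR: indptr needs rows+1 slots
+  //   res.alloc_sparse(B, 0, num_stored_rows);    // row_sparse: indptr unused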
+
  private:
   /*! \brief allocation lambda function */
   xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -405,6 +470,10 @@ class OpResource {
   void *cpu_alloc, *gpu_alloc;
   /*! \brief cuda stream passed from MXNet */
   void *cuda_stream;
+  /*! \brief sparse allocation lambda function */
+  sparse_malloc_t sparse_malloc;
+  /*! \brief lambda function to return allocated sparse memory handle */
+  void *sparse_alloc;
 };
 
 /*!
@@ -647,6 +716,8 @@ typedef MXReturnValue (*parseAttrs_t)(std::map<std::string, std::string>,
                                       int*, int*);
 typedef MXReturnValue (*inferType_t)(std::map<std::string, std::string>,
                                      std::vector<int>&, std::vector<int>&);
+typedef MXReturnValue (*inferSType_t)(std::map<std::string, std::string>,
+                                      std::vector<int>&, std::vector<int>&);
 typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
                                       std::vector<std::vector<unsigned int> >&,
                                       std::vector<std::vector<unsigned int> >&);
@@ -660,9 +731,9 @@ typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
  */
 class CustomOp {
  public:
-  explicit CustomOp(const char* op_name) :
-    name(op_name), parse_attrs(nullptr), infer_type(nullptr),
-    infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {}
+  explicit CustomOp(const char* op_name) : name(op_name),
+    parse_attrs(NULL), infer_type(NULL), infer_storage_type(NULL), infer_shape(NULL),
+    mutate_inputs(NULL), isSGop(false) {}
   CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
     if (forward_ctx_map.count(ctx) > 0)
       raiseDuplicateContextError();
@@ -683,6 +754,10 @@ class CustomOp {
     infer_type = func;
     return *this;
   }
+  CustomOp& setInferSType(inferSType_t func) {
+    infer_storage_type = func;
+    return *this;
+  }
   CustomOp& setInferShape(inferShape_t func) {
     infer_shape = func;
     return *this;
@@ -723,6 +798,7 @@ class CustomOp {
   /*! \brief operator functions */
   parseAttrs_t parse_attrs;
   inferType_t infer_type;
+  inferSType_t infer_storage_type;
   inferShape_t infer_shape;
   mutateInputs_t mutate_inputs;
   bool isSGop;
@@ -876,7 +952,7 @@ typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
                           const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
                           const char*** create_op_ctx, createOpState_t** create_op_fp,
                           int* create_op_count,
-                          parseAttrs_t* parse, inferType_t* type,
+                          parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
                           inferShape_t* shape, mutateInputs_t* mutate);
 
 #define MXLIB_OPCALLFREE_STR "_opCallFree"
@@ -898,6 +974,11 @@ typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
                                  const char* const* vals, int num,
                                  int* intypes, int num_in, int* outtypes, int num_out);
 
+#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
+typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
+                                  const char* const* vals, int num,
+                                  int* instypes, int num_in, int* outstypes, int num_out);
+
 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
 typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              const char* const* vals, int num,
@@ -910,7 +991,13 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              size_t* outIDs, const char** outdev_type,
                              int* outdev_id, int num_out,
                              xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream);
+                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                             int* instypes, int* outstypes,
+                             void** in_indices, void** out_indices,
+                             void** in_indptr, void** out_indptr,
+                             int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
 
 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
@@ -933,7 +1020,13 @@ typedef int
 (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                          size_t* outIDs, const char** outdev_type,
                          int* outdev_id, int num_out,
                          xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                         xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream);
+                         xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                         sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                         int* instypes, int* outstypes,
+                         void** in_indices, void** out_indices,
+                         void** in_indptr, void** out_indptr,
+                         int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                         int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
 
 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
 typedef int (*partRegSize_t)(void);
@@ -1004,12 +1097,13 @@ extern "C" {
                const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
                const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
                const char*** create_op_ctx, createOpState_t** create_op_fp,
                int* create_op_count,
-               parseAttrs_t* parse, inferType_t* type,
+               parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
                inferShape_t* shape, mutateInputs_t* mutate) {
     CustomOp &op = Registry<CustomOp>::get()->get(idx);
     *name = op.name;
     *parse = op.parse_attrs;
     *type = op.infer_type;
+    *stype = op.infer_storage_type;
     *shape = op.infer_shape;
     *mutate = op.mutate_inputs;
     *isSGop = op.isSGop;
@@ -1136,6 +1230,43 @@ extern "C" {
     return retval;
   }
 
+  /*! \brief returns status of calling inferSType function for operator from library */
+#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
+  __declspec(dllexport) int __cdecl
+#else
+  int
+#endif
+  _opCallInferSType(inferSType_t inferSType, const char* const* keys,
+                    const char* const* vals, int num,
+                    int* instypes, int num_in, int* outstypes, int num_out) {
+    // create map of attributes from list
+    std::map<std::string, std::string> attrs;
+    for (int i = 0; i < num; i++) {
+      attrs[std::string(keys[i])] = std::string(vals[i]);
+    }
+
+    // create a vector of storage types for inputs
+    std::vector<int> in_stypes(num_in);
+    for (int i = 0; i < num_in; i++) {
+      in_stypes[i] = instypes[i];
+    }
+
+    // create a vector of storage types for outputs
+    std::vector<int> out_stypes(num_out, -1);
+
+    int retval = inferSType(attrs, in_stypes, out_stypes);
+
+    if (!retval)
+      return retval;
+
+    // copy output storage types
+    for (int i = 0; i < num_out; i++) {
+      outstypes[i] = out_stypes[i];
+    }
+
+    return retval;
+  }
+
   /*! \brief returns status of calling Forward/Backward function for operator from library */
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
   __declspec(dllexport) int __cdecl
 #else
   int
 #endif
@@ -1148,7 +1279,12 @@ extern "C" {
                   const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
                   size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                   xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream) {
+                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                  sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                  int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                  void** in_indptr, void** out_indptr,
+                  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
     // create map of attributes from list
     std::map<std::string, std::string> attrs;
     for (int i = 0; i < num; i++) {
       attrs[std::string(keys[i])] = std::string(vals[i]);
     }
 
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      // Dense representation.
+      if (instypes[i] == 0) {
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i], indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      // Dense representation.
+      if (outstypes[i] == 0) {
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
+                             outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
+                            out_indices[i], out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+                             outshapes[i], outdims[i], outIDs[i], {outdev_type[i],
+                             outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, cuda_stream);
-
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   cuda_stream, sparse_malloc, sparse_alloc);
     return fcomp(attrs, inputs, outputs, res);
   }
 
@@ -1239,22 +1414,69 @@ extern "C" {
                           const int64_t** outshapes, int* outdims, void** outdata,
                           int* outtypes, size_t* outIDs, const char** outdev_type,
                           int* outdev_id, int num_out,
                           xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream) {
+                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                          sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                          int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                          void** in_indptr, void** out_indptr,
+                          int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                          int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      if (instypes[i] == 0) {
+        // Dense representation.
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), (MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i],
+                            indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    // create a vector for sparse outputs
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      if (outstypes[i] == 0) {
+        // Dense representation.
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i],
+                             outIDs[i], {outdev_type[i], outdev_id[i]}, kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), (MXDType)outtypes[i],
+                             outshapes[i], outdims[i], outIDs[i], {outdev_type[i],
+                             outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, stream);
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   stream, sparse_malloc, sparse_alloc);
 
     CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
     if (is_forward) {
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index db0e2629a5df..fe00a9a0718b 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -114,7 +114,7 @@ void CustomFComputeDispatcher(const std::string op_name,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
   std::vector<void*> in_data, out_data;
-  std::vector<const int64_t *> in_shapes, out_shapes;
+  std::vector<const int64_t*> in_shapes, out_shapes;
   std::vector<int> in_dims, out_dims;
   std::vector<int> in_types, out_types;
   std::vector<size_t> in_verIDs, out_verIDs;
@@ -122,6 +122,13 @@ void CustomFComputeDispatcher(const std::string op_name,
   std::vector<int> in_dev_id, out_dev_id;
   std::vector<NDArray> conv_mkl;  // converted NDArrays from MKLDNN format
 
+  // Extra data for sparse inputs and outputs.
+  std::vector<int> in_stypes(inputs.size(), 0), out_stypes(outputs.size(), 0);
+  std::vector<void*> in_indices(inputs.size(), nullptr), out_indices(outputs.size(), nullptr);
+  std::vector<void*> in_indptr(inputs.size(), nullptr), out_indptr(outputs.size(), nullptr);
+  std::vector<int64_t> in_indices_shapes(inputs.size(), 0), out_indices_shapes(outputs.size(), 0);
+  std::vector<int64_t> in_indptr_shapes(inputs.size(), 0), out_indptr_shapes(outputs.size(), 0);
+
   // convert inputs/outputs NDArray to C types to be passed to lib_api.h
   for (size_t i = 0; i < inputs.size(); i++) {
     NDArray const* in_nd = &(inputs[i]);
@@ -141,7 +148,19 @@ void CustomFComputeDispatcher(const std::string op_name,
     in_verIDs.push_back(in_nd->version());
     const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
     in_dev_type.push_back(ctx_str);
+
     in_dev_id.push_back(in_nd->ctx().real_dev_id());
+
+    if (inputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      in_stypes[i] = 1;
+      in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (inputs[i].storage_type() == mxnet::kCSRStorage) {
+      in_stypes[i] = 2;
+      in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_;
+      in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size();
+      in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
 
   for (size_t i = 0; i < outputs.size(); i++) {
@@ -153,6 +172,18 @@ void CustomFComputeDispatcher(const std::string op_name,
     const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
     out_dev_type.push_back(ctx_str);
     out_dev_id.push_back(outputs[i].ctx().real_dev_id());
+
+    if (outputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      out_stypes[i] = 1;
+      out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (outputs[i].storage_type() == mxnet::kCSRStorage) {
+      out_stypes[i] = 2;
+      out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_;
+      out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size();
+      out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
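+
+  // Note (comment added for clarity): storage types cross the C ABI as plain
+  // ints -- 0 = dense (kDefaultStorage), 1 = row_sparse (kRowSparseStorage),
+  // 2 = csr (kCSRStorage) -- mirroring the MXStorageType enum in lib_api.h.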
"cpu" : "gpu"; out_dev_type.push_back(ctx_str); out_dev_id.push_back(outputs[i].ctx().real_dev_id()); + + if (outputs[i].storage_type() == mxnet::kRowSparseStorage) { + out_stypes[i] = 1; + out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_; + out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size(); + } else if (outputs[i].storage_type() == mxnet::kCSRStorage) { + out_stypes[i] = 2; + out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_; + out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_; + out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size(); + out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size(); + } } // get memory resource and mxnet backend streams @@ -173,6 +204,24 @@ void CustomFComputeDispatcher(const std::string op_name, return workspace.dptr_; }; + // create lambda that allocates memory for sparse and + // returns allocated arrays for data, indices and indptr. + auto sparse_alloc = [&](int index, int indices_len, int idxptr_len, + void** data, int64_t** indices, int64_t** indptr) { + if (idxptr_len == 0) { + // Row Sparse + outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)}); + *data = outputs[index].data().dptr_; + *indices = reinterpret_cast(outputs[index].aux_data(rowsparse::kIdx).dptr_); + } else { + // CSR + outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), mshadow::Shape1(indices_len)}); + *data = outputs[index].data().dptr_; + *indices = reinterpret_cast(outputs[index].aux_data(csr::kIdx).dptr_); + *indptr = reinterpret_cast(outputs[index].aux_data(csr::kIndPtr).dptr_); + } + }; + // create lambda without captures so that we can cast it to function pointer // lambda with captures cannot be cast to function pointer and pass to lib_api.h // this needs to be a lambda function so that we can do the decltype cast @@ -189,6 +238,13 @@ void CustomFComputeDispatcher(const std::string op_name, return static_cast((*gpualloc)(size)); }; + typedef decltype(sparse_alloc) alloc_type_sparse; + auto sparse_malloc = [](void* _sparse_alloc, int index, int indices_len, int idxptr_len, + void** data, int64_t** indices, int64_t** indptr) { + alloc_type_sparse* sparsealloc = static_cast(_sparse_alloc); + (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr); + }; + // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h void *cuda_stream = nullptr; #if MXNET_USE_CUDA @@ -208,13 +264,18 @@ void CustomFComputeDispatcher(const std::string op_name, attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } + // call fcompute function CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(), out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(), - cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream)) + cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, + sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(), + in_indices.data(), out_indices.data(), in_indptr.data(), out_indptr.data(), + in_indices_shapes.data(), out_indices_shapes.data(), + in_indptr_shapes.data(), out_indptr_shapes.data())) << "Error calling FCompute for custom operator '" << op_name << "'"; } @@ -233,7 +294,12 @@ void CustomFComputeDispatcher(const std::string op_name, out_shapes.data(), out_dims.data(), out_data.data(), 
                             out_types.data(), out_verIDs.data(), out_dev_type.data(),
                             out_dev_id.data(), out_data.size(),
-                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream))
+                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                            sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                            in_indices.data(), out_indices.data(),
+                            in_indptr.data(), out_indptr.data(),
+                            in_indices_shapes.data(), out_indices_shapes.data(),
+                            in_indptr_shapes.data(), out_indptr_shapes.data()))
       << "Error calling FStatefulCompute for custom operator '" << op_name << "'";
   }
 }
@@ -272,6 +338,9 @@ int MXLoadLib(const char *path) {
   opCallInferType_t callInferType =
     get_func<opCallInferType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
 
+  opCallInferSType_t callInferSType =
+    get_func<opCallInferSType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERSTYPE_STR));
+
   opCallFComp_t callFComp =
     get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
 
@@ -306,6 +375,7 @@ int MXLoadLib(const char *path) {
     // function pointers holding implementation from custom library
     parseAttrs_t parse_fp = nullptr;
     inferType_t type_fp = nullptr;
+    inferSType_t stype_fp = nullptr;
     inferShape_t shape_fp = nullptr;
     // optional attributes
     mutateInputs_t mutate_fp = nullptr;
@@ -322,7 +392,7 @@ int MXLoadLib(const char *path) {
              &forward_ctx, &forward_fcomp, &forward_count,
              &backward_ctx, &backward_fcomp, &backward_count,
              &createop_ctx, &createop_fp, &createop_count,
-             &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+             &parse_fp, &type_fp, &stype_fp, &shape_fp, &mutate_fp);
 
     // construct maps of context to forward/backward custom library function
     std::unordered_map<std::string, fcomp_t> forward_ctx_map;
@@ -583,12 +653,39 @@ int MXLoadLib(const char *path) {
                                     DispatchMode* dispatch_mode,
                                     std::vector<int>* in_stypes,
                                     std::vector<int>* out_stypes) {
-      // TODO(ziyimu): remove this dense enforce check after supporting sparse tensor
-      CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
-      << "Error input tensors are not dense for custom operator '" << name_str << "'";
-      // set outputs as dense
-      return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+      if (stype_fp == nullptr) {
+        // InferSType is not defined in the custom library.
+        CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
+        << "Error input tensors are not dense for custom operator '" << name_str << "'";
+        // set outputs as dense
+        return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+      } else {
+        // InferSType is defined in the custom library.
+        // convert attributes to vector of char*
+        std::vector<const char*> attr_keys, attr_vals;
+        for (auto kv : attrs.dict) {
+          attr_keys.push_back(kv.first.c_str());
+          attr_vals.push_back(kv.second.c_str());
+        }
+        // copy input storage types from in_stypes
+        std::vector<int> instypes(*in_stypes);
+
+        // output storage types will be populated by the inferSType function
+        std::vector<int> outstypes(out_stypes->size());
+        CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                             instypes.data(), in_stypes->size(),
+                             outstypes.data(), out_stypes->size()))
+        << "Error calling InferSType for custom operator '" << name_str << "'";
+
+        // copy and assign output storage types from the custom op to MXNet memory.
+        for (size_t i = 0; i < out_stypes->size(); i++) {
+          STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]);
+        }
+        // assign dispatch mode
+        DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+        return true;
+      }
     };
 
     // FGradient register lambda
@@ -698,8 +795,8 @@ int MXLoadLib(const char *path) {
     regOp.set_num_inputs(num_inputs);
     regOp.set_num_outputs(num_outputs);
     regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
-    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
     regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
+    regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
     regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
     // optionally add fmutate inputs if user specified a function
     if (mutate_fp != nullptr)