diff --git a/CMakeLists.txt b/CMakeLists.txt
index d6b348aea99f..016cc8ba5b82 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -685,10 +685,6 @@ if(MSVC)
 endif()

-add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
-add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
-target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
-target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
 set(MXNET_INSTALL_TARGETS mxnet)
 if(UNIX)
   string(APPEND CMAKE_CUDA_FLAGS "${CUDA_ARCH_FLAGS_SPACES}")
@@ -701,15 +697,8 @@ if(UNIX)
   target_link_libraries(mxnet PRIVATE ${BEGIN_WHOLE_ARCHIVE} $<TARGET_FILE:mxnet_static> ${END_WHOLE_ARCHIVE})
   target_link_libraries(mxnet PRIVATE mxnet_static)
   target_link_libraries(mxnet_static PUBLIC ${CMAKE_DL_LIBS})
-  target_compile_options(sample_lib PUBLIC -shared)
-  target_compile_options(subgraph_lib PUBLIC -shared)
   set_target_properties(mxnet_static PROPERTIES OUTPUT_NAME mxnet)
 elseif(MSVC)
-  target_compile_options(sample_lib PUBLIC /LD)
-  target_compile_options(subgraph_lib PUBLIC /LD)
-  set_target_properties(sample_lib PROPERTIES PREFIX "lib")
-  set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
-
   if(USE_CUDA)
     if(MSVC)
       if(USE_SPLIT_ARCH_DLL)
@@ -762,6 +751,31 @@ elseif(MSVC)
 endif()

+add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
+add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
+target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+if (USE_CUDA)
+  add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
+  target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
+endif()
+if(UNIX)
+  target_compile_options(customop_lib PUBLIC -shared)
+  target_compile_options(subgraph_lib PUBLIC -shared)
+  if (USE_CUDA)
+    target_compile_options(customop_gpu_lib PUBLIC -shared)
+  endif()
+elseif(MSVC)
+  target_compile_options(customop_lib PUBLIC /LD)
+  target_compile_options(subgraph_lib PUBLIC /LD)
+  set_target_properties(customop_lib PROPERTIES PREFIX "lib")
+  set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
+  if (USE_CUDA)
+    target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
+    set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
+  endif()
+endif()
+
 if(USE_DIST_KVSTORE)
   add_subdirectory("3rdparty/ps-lite")
   add_definitions(-DMXNET_USE_DIST_KVSTORE)
diff --git a/Makefile b/Makefile
index 16d7d2393736..49c84c55fcfe 100644
--- a/Makefile
+++ b/Makefile
@@ -457,7 +457,7 @@ endif
 .PHONY: clean all extra-packages test lint clean_all rcpplint rcppexport roxygen\
 	cython2 cython3 cython cyclean

-all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib subgraph_lib
+all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs

 SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
 OBJ = $(patsubst %.cc, build/%.o, $(SRC))
@@ -664,11 +664,19 @@ cpplint:
 pylint:
 	python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet

-# sample lib for MXNet extension dynamically loading custom operator
-sample_lib:
-	$(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc
-o libsample_lib.so -I include/mxnet +# MXNet extension dynamically loading libraries +EXT_LIBS = custom_op_lib subgraph_lib +ifeq ($(USE_CUDA), 1) + EXT_LIBS += custom_op_gpu_lib +endif +extension_libs: $(EXT_LIBS) + +custom_op_lib: + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o build/libcustomop_lib.so -I include/mxnet +custom_op_gpu_lib: + $(NVCC) -shared -std=c++11 -Xcompiler -fPIC example/extensions/lib_custom_op/relu_lib.cu -o build/libcustomop_gpu_lib.so -I include/mxnet subgraph_lib: - $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o libsubgraph_lib.so -I include/mxnet + $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_subgraph/subgraph_lib.cc -o build/libsubgraph_lib.so -I include/mxnet # Cython build cython: @@ -734,7 +742,6 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - - $(RM) libsample_lib.so libsubgraph_lib.so $(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) $(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS)) else @@ -746,7 +753,6 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - - $(RM) libsample_lib.so libsubgraph_lib.so endif clean_all: clean diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index c697e1e58788..2f469b934d1c 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -23,23 +23,23 @@ utils = load('ci/Jenkinsfile_utils.groovy') // mxnet libraries -mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // Python wheels mx_pip = 'build/*.whl' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. 
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' +mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. -mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' +mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' -mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, 
deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*' // Python unittest for CPU diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile index 090d17d98a22..edd753b0759c 100644 --- a/example/extensions/lib_custom_op/Makefile +++ b/example/extensions/lib_custom_op/Makefile @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. -all: gemm_lib +all: gemm_lib relu_lib gemm_lib: g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet +relu_lib: + nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet + clean: - rm -rf libgemm_lib.so + rm -rf libgemm_lib.so librelu_lib.so diff --git a/example/extensions/lib_custom_op/gemm_lib.cc b/example/extensions/lib_custom_op/gemm_lib.cc index 3835207e2a16..daeac337f4d6 100644 --- a/example/extensions/lib_custom_op/gemm_lib.cc +++ b/example/extensions/lib_custom_op/gemm_lib.cc @@ -103,7 +103,7 @@ MXReturnValue backward(std::map attrs, unsigned m = inputs[2].shape[1]; // allocate temporary workspace memory through resource manager // for multiple arrays better to request a big memory pool - void *workspace = res.alloc((k*n + m*k) * sizeof(float)); + void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float)); float *At = static_cast(workspace); float *Bt = static_cast(workspace) + (k*n); @@ -167,8 +167,8 @@ MXReturnValue inferShape(std::map attrs, } REGISTER_OP(my_gemm) -.setForward(forward) -.setBackward(backward) +.setForward(forward, "cpu") +.setBackward(backward, "cpu") .setParseAttrs(parseAttrs) .setInferType(inferType) .setInferShape(inferShape); @@ -182,8 +182,7 @@ class MyStatefulGemm : public CustomStatefulOp { MXReturnValue Forward(std::vector inputs, std::vector outputs, OpResource op_res) { - ++count; - std::cout << "Info: keyword + number of forward: " << count << std::endl; + std::cout << "Info: keyword + number of forward: " << ++count << std::endl; std::map attrs; return forward(attrs, inputs, outputs, op_res); } @@ -203,9 +202,9 @@ class MyStatefulGemm : public CustomStatefulOp { MXReturnValue createOpState(std::map attrs, CustomStatefulOp** op_inst) { - int count = 0; - if (attrs.count("test_kw") > 0) - count = std::stoi(attrs["test_kw"]); + // testing passing of keyword arguments + int count = attrs.count("test_kw") > 0 ? 
std::stoi(attrs["test_kw"]) : 0; + // creating stateful operator instance *op_inst = new MyStatefulGemm(count); std::cout << "Info: stateful operator created" << std::endl; return MX_SUCCESS; @@ -222,7 +221,7 @@ REGISTER_OP(state_gemm) .setInferType(inferType) .setInferShape(inferShape) .setMutateInputs(mutateInputs) -.setCreateOpState(createOpState); +.setCreateOpState(createOpState, "cpu"); MXReturnValue initialize(int version) { if (version >= 10400) { diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu new file mode 100644 index 000000000000..3beb68c20fa7 --- /dev/null +++ b/example/extensions/lib_custom_op/relu_lib.cu @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file relu_lib.cu + * \brief simple custom relu operator implemented using CUDA function + */ + +#include +#include "lib_api.h" + +__global__ void relu_gpu_forward(float *out, float *in, int64_t N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) + out[tid] = in[tid] > 0 ? in[tid] : 0; +} + +__global__ void relu_gpu_backward(float *ingrad, float *outgrad, float *indata, int64_t N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) + ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0; +} + +MXReturnValue forwardCPU(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + float* in_data = inputs[0].data(); + float* out_data = outputs[0].data(); + for (int i=0; i 0 ? in_data[i] : 0; + } + return MX_SUCCESS; +} + +MXReturnValue backwardCPU(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + float* out_grad = inputs[0].data(); + float* in_data = inputs[1].data(); + float* in_grad = outputs[0].data(); + for (int i=0; i 0 ? 
1 * out_grad[i] : 0; + } + return MX_SUCCESS; +} + +MXReturnValue forwardGPU(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + float* in_data = inputs[0].data(); + float* out_data = outputs[0].data(); + + mx_stream_t cuda_stream = res.get_cuda_stream(); + int64_t N = inputs[0].size(); + int block = 256; + int grid = (N + (block - 1)) / block; + relu_gpu_forward<<>>(out_data, in_data, N); + + return MX_SUCCESS; +} + +MXReturnValue backwardGPU(std::map attrs, + std::vector inputs, + std::vector outputs, + OpResource res) { + float* out_grad = inputs[0].data(); + float* in_data = inputs[1].data(); + float* in_grad = outputs[0].data(); + + mx_stream_t cuda_stream = res.get_cuda_stream(); + int64_t N = inputs[0].size(); + int block = 256; + int grid = (N + (block - 1)) / block; + relu_gpu_backward<<>>(in_grad, out_grad, in_data, N); + + return MX_SUCCESS; +} + +MXReturnValue parseAttrs(std::map attrs, int* num_in, int* num_out) { + *num_in = 1; + *num_out = 1; + return MX_SUCCESS; +} + +MXReturnValue inferType(std::map attrs, + std::vector &intypes, + std::vector &outtypes) { + outtypes[0] = intypes[0]; + return MX_SUCCESS; +} + +MXReturnValue inferShape(std::map attrs, + std::vector> &inshapes, + std::vector> &outshapes) { + outshapes[0] = inshapes[0]; + return MX_SUCCESS; +} + +REGISTER_OP(my_relu) +.setParseAttrs(parseAttrs) +.setInferType(inferType) +.setInferShape(inferShape) +.setForward(forwardCPU, "cpu") +.setForward(forwardGPU, "gpu") +.setBackward(backwardCPU, "cpu") +.setBackward(backwardGPU, "gpu"); + +class MyStatefulReluCPU : public CustomStatefulOp { +public: + explicit MyStatefulReluCPU() {} + MXReturnValue Forward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::map attrs; + return forwardCPU(attrs, inputs, outputs, op_res); + } + MXReturnValue Backward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::map attrs; + return backwardCPU(attrs, inputs, outputs, op_res); + } + ~MyStatefulReluCPU() {} +}; + +class MyStatefulReluGPU : public CustomStatefulOp { +public: + explicit MyStatefulReluGPU() {} + MXReturnValue Forward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::map attrs; + return forwardGPU(attrs, inputs, outputs, op_res); + } + MXReturnValue Backward(std::vector inputs, + std::vector outputs, + OpResource op_res) { + std::map attrs; + return backwardGPU(attrs, inputs, outputs, op_res); + } + ~MyStatefulReluGPU() {} +}; + +MXReturnValue createOpStateCPU(std::map attrs, + CustomStatefulOp** op_inst) { + *op_inst = new MyStatefulReluCPU(); + return MX_SUCCESS; +} + +MXReturnValue createOpStateGPU(std::map attrs, + CustomStatefulOp** op_inst) { + *op_inst = new MyStatefulReluGPU(); + return MX_SUCCESS; +} + +REGISTER_OP(my_state_relu) +.setParseAttrs(parseAttrs) +.setInferType(inferType) +.setInferShape(inferShape) +.setCreateOpState(createOpStateCPU, "cpu") +.setCreateOpState(createOpStateGPU, "gpu"); + +MXReturnValue initialize(int version) { + if (version >= 10400) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_custom_op/test_relu.py b/example/extensions/lib_custom_op/test_relu.py new file mode 100644 index 000000000000..ce2b2fe99cf0 --- /dev/null +++ b/example/extensions/lib_custom_op/test_relu.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks dynamic loading of custom library into MXNet +# and checks end to end compute of a simple 2D gemm custom op + +import mxnet as mx +import os +import time + +#load library +if (os.name=='posix'): + path = os.path.abspath('librelu_lib.so') + mx.library.load(path) + +a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu()) +b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu()) + +print("--------start ndarray compute---------") +print(mx.nd.my_relu(a)) +print(mx.nd.my_relu(b)) +print(mx.nd.my_state_relu(a)) +print(mx.nd.my_state_relu(b)) + +print("--------start symbolic compute--------") +c = mx.sym.Variable('c') +d = mx.sym.Variable('d') +e = mx.sym.my_relu(c) +base = mx.sym.relu(d) +in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())] +in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())] +exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad) +exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base) +out = exe.forward() +out_base = exe_base.forward() +print(out) +print(out_base) + +print("--------start backward compute--------") +out_grad = mx.nd.ones((2,2), ctx=mx.gpu()) +exe.backward([out_grad]) +exe_base.backward([out_grad]) +print(in_grad) +print(in_grad_base) + +print("--------start testing larger ndarray---------") +a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu()) +b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu()) +t1 = time.time() +r1 = mx.nd.my_relu(a) +mx.nd.waitall() +t2 = time.time() +r2 = mx.nd.my_relu(b) +mx.nd.waitall() +t3 = time.time() +r3 = mx.nd.relu(b) +mx.nd.waitall() +t4 = time.time() +print("CPU running time:") +print(t2 - t1) +print("GPU running time:") +print(t3 - t2) +print("Baseline GPU running time:") +print(t4 - t3) diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index 3ebdfc138a79..0727eb786ad8 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector inputs, // get input tensor based on node ID inputs from data storage MXTensor &input = data[node_inputs.list[0].list[0].num]; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0); + MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}); // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute log operator @@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector inputs, // get input tensor based on node ID inputs from data storage MXTensor &input = data[node_inputs.list[0].list[0].num]; // create temporary storage - MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0); + MXTensor tmp(malloc(input.size()*4), input.shape, 
input.dtype, 0, {"cpu", 0}); // save allocated ptr to free later to_free.push_back(tmp.data_ptr); // execute exp operator @@ -172,7 +172,7 @@ MXReturnValue createOpState(std::map attrs, REGISTER_OP(_custom_subgraph_op) .setIsSubgraphOp() -.setCreateOpState(createOpState); +.setCreateOpState(createOpState, "cpu"); const std::vector op_names({"exp","log"}); diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index cc0ec0f938af..21f5cea125e4 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -39,7 +39,18 @@ #include #include -#define MX_LIBRARY_VERSION 2 +#define MX_LIBRARY_VERSION 3 + +/*! + * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple + * times may lead to undefined behaviour, so we need to set symbol visibility to hidden + * see https://labjack.com/news/simple-cpp-symbol-visibility-demo for details + */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + #define PRIVATE_SYMBOL +#else + #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden"))) +#endif /* * Import from DLPack /~https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h @@ -203,6 +214,16 @@ enum MXDType { kUNSET = 100, }; +/*! + * \brief Context info passing from MXNet OpContext + * dev_type is string repr of supported context, currently only "cpu" and "gpu" + * dev_id is the device index where the tensor locates + */ +typedef struct { + std::string dev_type; + int dev_id; +} MXContext; + enum MXReturnValue { MX_FAIL = 0, MX_SUCCESS = 1, @@ -215,13 +236,13 @@ struct MXTensor { MXTensor() : data_ptr(NULL), dtype(kUNSET), verID(0) {} MXTensor(void *data_ptr, const std::vector &shape, MXDType dtype, - size_t vID) - : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID) {} + size_t vID, MXContext mx_ctx) + : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {} /*! \brief populate internal tensor fields */ - void setTensor(void *dptr, MXDType type, const int64_t* dims, - int ndims, size_t vID) { - data_ptr = dptr; dtype = type; verID = vID; + void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims, + size_t vID, MXContext mx_ctx) { + data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; shape.clear(); for (int j = 0; j < ndims; j++) { shape.push_back(dims[j]); @@ -232,13 +253,28 @@ struct MXTensor { /*! \brief populate DLTensor fields */ void setDLTensor() { dltensor.data = data_ptr; - dltensor.ctx.device_type = kDLCPU; - dltensor.ctx.device_id = 0; dltensor.ndim = shape.size(); dltensor.shape = const_cast(shape.data()); dltensor.strides = NULL; dltensor.byte_offset = 0; dltensor.dtype.lanes = 1; + dltensor.ctx.device_id = ctx.dev_id; + if (ctx.dev_type == "cpu") + dltensor.ctx.device_type = kDLCPU; + else if (ctx.dev_type == "gpu") + dltensor.ctx.device_type = kDLGPU; + else if (ctx.dev_type == "opencl") + dltensor.ctx.device_type = kDLOpenCL; + else if (ctx.dev_type == "vulcan") + dltensor.ctx.device_type = kDLVulkan; + else if (ctx.dev_type == "metal") + dltensor.ctx.device_type = kDLMetal; + else if (ctx.dev_type == "vpi") + dltensor.ctx.device_type = kDLVPI; + else if (ctx.dev_type == "rocm") + dltensor.ctx.device_type = kDLROCM; + else + dltensor.ctx.device_type = kDLExtDev; switch (dtype) { case kFloat32: dltensor.dtype.code = kDLFloat; @@ -295,9 +331,11 @@ struct MXTensor { /*! 
\brief helper function to compare two MXTensors */
  inline bool isSame(const MXTensor &oth) const {
    return data_ptr == oth.data_ptr &&
-           dtype == oth.dtype &&
-           verID == oth.verID &&
-           shape == oth.shape;
+           dtype == oth.dtype &&
+           verID == oth.verID &&
+           ctx.dev_type == oth.ctx.dev_type &&
+           ctx.dev_id == oth.ctx.dev_id &&
+           shape == oth.shape;
  }

  // data is flatten 1D repr of tensor, elements are in continuous memory
@@ -313,31 +351,55 @@ struct MXTensor {
  // version number updated if the tensor has changed since the last use by custom op
  size_t verID;

+  // context of MXTensor representing which device the tensor data is located
+  MXContext ctx;
+
  // corresponding DLTensor repr of MXTensor
  // easy way to reuse functions taking DLTensor
  DLTensor dltensor;
};

-/*!
- * \brief resource malloc function to allocate memory inside Forward/Backward functions
- */
+/*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
typedef void* (*xpu_malloc_t)(void*, int);

+#if defined(__NVCC__)
+  typedef cudaStream_t mx_stream_t;
+#else
+  typedef void* mx_stream_t;
+#endif
+
/*!
 * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions
 */
class OpResource {
 public:
-  OpResource(xpu_malloc_t cm, void* ca) : cpu_malloc(cm), cpu_alloc(ca) {}
+  OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
+             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream)
+    : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
+      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream) {}

-  /*! \brief allocate memory controlled by MXNet */
-  void* alloc(int size) {
+  /*! \brief allocate cpu memory controlled by MXNet */
+  void* alloc_cpu(int size) {
    return cpu_malloc(cpu_alloc, size);
  }

+  /*! \brief allocate gpu memory controlled by MXNet */
+  void* alloc_gpu(int size) {
+    return gpu_malloc(gpu_alloc, size);
+  }
+
+  /*! \brief return the cuda stream object with correct type */
+  mx_stream_t get_cuda_stream() {
+    return static_cast<mx_stream_t>(cuda_stream);
+  }
+
 private:
-  xpu_malloc_t cpu_malloc;
-  void* cpu_alloc;
+  /*! \brief allocation lambda function */
+  xpu_malloc_t cpu_malloc, gpu_malloc;
+  /*! \brief lambda function to return allocated memory handle */
+  void *cpu_alloc, *gpu_alloc;
+  /*! \brief cuda stream passed from MXNet */
+  void *cuda_stream;
};

/*!
@@ -558,7 +620,7 @@ typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
typedef MXReturnValue (*mutateInputs_t)(std::map<std::string, std::string>,
                                        std::vector<int>&);
typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
-                                          CustomStatefulOp**);
+                                         CustomStatefulOp**);
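The expanded OpResource above hands a GPU custom op two things: MXNet-owned device scratch memory (`alloc_gpu`) and the CUDA stream MXNet scheduled the op on (`get_cuda_stream`). A minimal sketch of how an nvcc-compiled extension library might use both, in the style of this PR's relu_lib.cu; the `copy_scale` kernel and `myForwardGPU` are illustrative, only the OpResource and MXTensor calls come from lib_api.h:

```cpp
#include "lib_api.h"

// Hypothetical kernel; stages data through MXNet-owned GPU scratch space.
__global__ void copy_scale(float* out, float* tmp, const float* in, int64_t N) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    tmp[tid] = in[tid];
    out[tid] = 2.0f * tmp[tid];
  }
}

MXReturnValue myForwardGPU(std::map<std::string, std::string> attrs,
                           std::vector<MXTensor> inputs,
                           std::vector<MXTensor> outputs,
                           OpResource res) {
  int64_t N = inputs[0].size();
  // GPU scratch memory is owned by MXNet's resource manager, not the library
  float* tmp = static_cast<float*>(res.alloc_gpu(N * sizeof(float)));
  // launch on the stream MXNet scheduled this op on, so ordering with the
  // rest of the GPU work is preserved
  mx_stream_t stream = res.get_cuda_stream();
  int block = 256;
  int grid = (N + block - 1) / block;
  copy_scale<<<grid, block, 0, stream>>>(outputs[0].data<float>(), tmp,
                                         inputs[0].data<float>(), N);
  return MX_SUCCESS;
}
```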
/*!
 * \brief Class to hold custom operator registration
@@ -566,16 +628,17 @@ typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
 class CustomOp {
  public:
   explicit CustomOp(const char* op_name) : name(op_name),
-    forward(NULL), backward(NULL), parse_attrs(NULL), infer_type(NULL),
-    infer_shape(NULL), mutate_inputs(NULL), create_opstate(NULL),
-    isSGop(false) {}
-  ~CustomOp() {}
-  CustomOp& setForward(fcomp_t fcomp) {
-    forward = fcomp;
+    parse_attrs(NULL), infer_type(NULL), infer_shape(NULL), mutate_inputs(NULL), isSGop(false) {}
+  CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
+    if (forward_ctx_map.count(ctx) > 0)
+      raiseDuplicateContextError();
+    forward_ctx_map[ctx] = fcomp;
     return *this;
   }
-  CustomOp& setBackward(fcomp_t fcomp) {
-    backward = fcomp;
+  CustomOp& setBackward(fcomp_t fgrad, const char* ctx) {
+    if (backward_ctx_map.count(ctx) > 0)
+      raiseDuplicateContextError();
+    backward_ctx_map[ctx] = fgrad;
     return *this;
   }
   CustomOp& setParseAttrs(parseAttrs_t func) {
@@ -594,26 +657,58 @@ class CustomOp {
     mutate_inputs = func;
     return *this;
   }
-  CustomOp& setCreateOpState(createOpState_t func) {
-    create_opstate = func;
+  CustomOp& setCreateOpState(createOpState_t func, const char* ctx) {
+    if (create_op_ctx_map.count(ctx) > 0)
+      raiseDuplicateContextError();
+    create_op_ctx_map[ctx] = func;
     return *this;
   }
   CustomOp& setIsSubgraphOp() {
     isSGop = true;
     return *this;
   }
+  void mapToVector() {
+    for (auto kv : forward_ctx_map) {
+      forward_ctx_cstr.push_back(kv.first);
+      forward_fp.push_back(kv.second);
+    }
+    for (auto kv : backward_ctx_map) {
+      backward_ctx_cstr.push_back(kv.first);
+      backward_fp.push_back(kv.second);
+    }
+    for (auto kv : create_op_ctx_map) {
+      create_op_ctx_cstr.push_back(kv.first);
+      create_op_fp.push_back(kv.second);
+    }
+  }
+  ~CustomOp() {}

   /*! \brief operator name */
   const char* name;
+
   /*! \brief operator functions */
-  fcomp_t forward;
-  fcomp_t backward;
   parseAttrs_t parse_attrs;
   inferType_t infer_type;
   inferShape_t infer_shape;
   mutateInputs_t mutate_inputs;
-  createOpState_t create_opstate;
   bool isSGop;
+
+  /*! \brief vector repr of ctx map to be easily loaded from c_api */
+  std::vector<const char*> forward_ctx_cstr, backward_ctx_cstr, create_op_ctx_cstr;
+  std::vector<fcomp_t> forward_fp, backward_fp;
+  std::vector<createOpState_t> create_op_fp;
+
+ private:
+  void raiseDuplicateContextError() {
+    std::string op_name_str(name);
+    throw std::runtime_error(
+      "Error! Cannot register multiple functions under same context for operator '"
+      + op_name_str + "'");
+  }
+
+  /*! \brief dedup context maps - static string ctx to custom function */
+  std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
+  std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
};
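With CustomOp keyed by context string as above, a library registers one implementation per device type, and registering the same context twice throws via raiseDuplicateContextError. A hedged sketch of the resulting registration pattern; `my_dual_ctx_op` and the stub bodies are placeholders with the lib_api.h typedef signatures, relu_lib.cu in this PR shows the full version:

```cpp
#include "lib_api.h"

// Placeholder implementations; a real library fills these in.
MXReturnValue parseAttrs(std::map<std::string, std::string> attrs,
                         int* num_in, int* num_out) {
  *num_in = 1; *num_out = 1;
  return MX_SUCCESS;
}
MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
                         std::vector<MXTensor> inputs, std::vector<MXTensor> outputs,
                         OpResource res) { return MX_SUCCESS; }
MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
                         std::vector<MXTensor> inputs, std::vector<MXTensor> outputs,
                         OpResource res) { return MX_SUCCESS; }

REGISTER_OP(my_dual_ctx_op)
.setParseAttrs(parseAttrs)
.setForward(forwardCPU, "cpu")   // dispatched when the inputs live on CPU
.setForward(forwardGPU, "gpu");  // dispatched when the inputs live on GPU
// A second .setForward(..., "cpu") here would hit the duplicate-context
// error and throw when the library is loaded.
```

/*!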
\brief Custom Subgraph Create function template */ @@ -673,7 +768,7 @@ class Registry { * \brief get singleton pointer to class * \returns pointer to class */ - static Registry* get() { + static Registry* get() PRIVATE_SYMBOL { static Registry inst; return &inst; } @@ -690,7 +785,7 @@ class Registry { return entries.size(); } T& get(int idx) { - return *(entries[idx]); + return *(entries.at(idx)); } private: @@ -740,68 +835,94 @@ class Registry { typedef int (*opRegSize_t)(void); #define MXLIB_OPREGGET_STR "_opRegGet" -typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*, - parseAttrs_t*, inferType_t*, - inferShape_t*, mutateInputs_t*, - createOpState_t*, int*); +typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop, + const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count, + const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count, + const char*** create_op_ctx, createOpState_t** create_op_fp, + int* create_op_count, + parseAttrs_t* parse, inferType_t* type, + inferShape_t* shape, mutateInputs_t* mutate); #define MXLIB_OPCALLFREE_STR "_opCallFree" -typedef int (*opCallFree_t)(void*); +typedef int (*opCallFree_t)(void* ptr); #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" -typedef int (*opCallParseAttrs_t)(parseAttrs_t, const char* const*, const char* const*, int, - int*, int*); +typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* keys, + const char* const* vals, int num, + int* num_in, int* num_out); #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape" -typedef int (*opCallInferShape_t)(inferShape_t, const char* const*, const char* const*, int, - unsigned int**, int*, int, - unsigned int***, int**, int); +typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys, + const char* const* vals, int num, + unsigned int** inshapes, int* indims, int num_in, + unsigned int*** outshapes, int** outdims, int num_out); #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" -typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* const*, int, - int*, int, int*, int); +typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys, + const char* const* vals, int num, + int* intypes, int num_in, int* outtypes, int num_out); #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" -typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int, - const int64_t**, int*, void**, int*, size_t*, int, - const int64_t**, int*, void**, int*, size_t*, int, - xpu_malloc_t, void*); +typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys, + const char* const* vals, int num, + const int64_t** inshapes, int* indims, + void** indata, int* intypes, + size_t* inIDs, const char** indev_type, + int* indev_id, int num_in, + const int64_t** outshapes, int* outdims, + void** outdata, int* outtypes, + size_t* outIDs, const char** outdev_type, + int* outdev_id, int num_out, + xpu_malloc_t cpu_malloc, void* cpu_alloc, + xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream); #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" -typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int, - int**, int*); +typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys, + const char* const* vals, int num, + int** mutate_indices, int* indices_size); #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" -typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const char* 
const*, int, - void**); +typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys, + const char* const* vals, int num, + void** state_op); #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" -typedef int (*opCallFStatefulComp_t)(int, void*, const int64_t**, int*, void**, int*, size_t*, - int, const int64_t**, int*, void**, int*, size_t*, - int, xpu_malloc_t, void*); +typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op, + const int64_t** inshapes, int* indims, + void** indata, int* intypes, + size_t* inIDs, const char** indev_type, + int* indev_id, int num_in, + const int64_t** outshapes, int* outdims, + void** outdata, int* outtypes, + size_t* outIDs, const char** outdev_type, + int* outdev_id, int num_out, + xpu_malloc_t cpu_malloc, void* cpu_alloc, + xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream); #define MXLIB_PARTREGSIZE_STR "_partRegSize" typedef int (*partRegSize_t)(void); #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount" -typedef int (*partRegGetCount_t)(int, const char**); +typedef int (*partRegGetCount_t)(int idx, const char** name); #define MXLIB_PARTREGGET_STR "_partRegGet" -typedef void (*partRegGet_t)(int, int, const char**, supportedOps_t*, - acceptSubgraph_t*, const char**); +typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy, + supportedOps_t* supportedOps, acceptSubgraph_t* acceptSubgraph, + const char** op_name); #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" -typedef int (*partCallSupportedOps_t)(supportedOps_t, const char*, int, int *, - const char* const*, const char* const*, int); +typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, + int num_ids, int *ids, const char* const* opt_keys, + const char* const* opt_vals, int num_opts); + #define MXLIB_PARTCALLACCEPTSUBGRAPH_STR "_partCallAcceptSubgraph" -typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, - const char *json, int subgraph_id, - int *accept, const char* const*, - const char* const*, int, - char***, char***, int*); +typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, const char *json, + int subgraph_id, int *accept, const char* const* opt_keys, + const char* const* opt_vals, int num_opts, + char*** attr_keys, char*** attr_vals, int *num_attrs); #define MXLIB_INITIALIZE_STR "initialize" -typedef int (*initialize_t)(int); +typedef int (*initialize_t)(int version); #define MXLIB_OPVERSION_STR "_opVersion" typedef int (*opVersion_t)(); @@ -833,20 +954,29 @@ extern "C" { #else void #endif - _opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad, + _opRegGet(int idx, const char** name, int *isSGop, + const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count, + const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count, + const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count, parseAttrs_t* parse, inferType_t* type, - inferShape_t* shape, mutateInputs_t* mutate, - createOpState_t* create_op, int *isSGop) { - CustomOp op = Registry::get()->get(idx); + inferShape_t* shape, mutateInputs_t* mutate) { + CustomOp &op = Registry::get()->get(idx); *name = op.name; - *fcomp = op.forward; - *fgrad = op.backward; *parse = op.parse_attrs; *type = op.infer_type; *shape = op.infer_shape; *mutate = op.mutate_inputs; - *create_op = op.create_opstate; *isSGop = op.isSGop; + op.mapToVector(); + *forward_ctx = op.forward_ctx_cstr.data(); + *forward_fp = op.forward_fp.data(); + 
*forward_count = op.forward_fp.size();
+    *backward_ctx = op.backward_ctx_cstr.data();
+    *backward_fp = op.backward_fp.data();
+    *backward_count = op.backward_fp.size();
+    *create_op_ctx = op.create_op_ctx_cstr.data();
+    *create_op_fp = op.create_op_fp.data();
+    *create_op_count = op.create_op_fp.size();
  }

  /*! \brief calls free from the external library for library allocated arrays */
@@ -966,13 +1096,13 @@ extern "C" {
#else
  int
#endif
-  _opCallFCompute(fcomp_t fcomp, const char* const* keys,
-                  const char* const* vals, int num,
-                  const int64_t** inshapes, int* indims,
-                  void** indata, int* intypes, size_t* inIDs, int num_in,
-                  const int64_t** outshapes, int* outdims,
-                  void** outdata, int* outtypes, size_t* outIDs, int num_out,
-                  xpu_malloc_t cpu_malloc, void* cpu_alloc) {
+  _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, int num,
+                  const int64_t** inshapes, int* indims, void** indata, int* intypes,
+                  size_t* inIDs, const char** indev_type, int* indev_id, int num_in,
+                  const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
+                  size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
+                  xpu_malloc_t cpu_malloc, void* cpu_alloc,
+                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream) {
    // create map of attributes from list
    std::map<std::string, std::string> attrs;
    for (int i = 0; i < num; i++) {
      attrs[keys[i]] = vals[i];
    }
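    // Everything crossing this library boundary is a flat C type (pointers,
    // ints, C strings), never a C++ object: e.g. hypothetical keys = {"test_kw"}
    // and vals = {"2"} arrive as parallel arrays and become
    // attrs["test_kw"] == "2" above, and the raw shape/type/version/device
    // arrays below are rebuilt into MXTensor objects so the custom op sees
    // normal C++ types on its side of the ABI.
    // create a vector of tensors for inputs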
std::vector inputs(num_in); for (int i = 0; i < num_in; i++) { - inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], inIDs[i]); + inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], + inIDs[i], {indev_type[i], indev_id[i]}); } // create a vector of tensors for outputs std::vector outputs(num_out); for (int i = 0; i < num_out; i++) { outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i]); + outIDs[i], {outdev_type[i], outdev_id[i]}); } - OpResource res(cpu_malloc, cpu_alloc); + + OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, stream); + CustomStatefulOp* op_ptr = reinterpret_cast(state_op); if (is_forward) { return op_ptr->Forward(inputs, outputs, res); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 54c544ae9415..d64231612592 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -99,7 +99,136 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, // NOTE: return value is added in API_END /*! - * \brief Loads dynamic library and initializes it + * \brief Common compute function dispatcher for forward/backward and stateful forward/backward + * state_ptr will be nullptr for regular ops; fcomp_fp is nullptr for stateful ops + */ +void CustomFComputeDispatcher(const std::string op_name, + const opCallFComp_t callFComp, + const fcomp_t fcomp_fp, + const nnvm::NodeAttrs* attrs, + const opCallFStatefulComp_t callFStatefulComp, + int stateful_forward_flag, + const OpStatePtr* state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_data, out_data; + std::vector in_shapes, out_shapes; + std::vector in_dims, out_dims; + std::vector in_types, out_types; + std::vector in_verIDs, out_verIDs; + std::vector in_dev_type, out_dev_type; + std::vector in_dev_id, out_dev_id; + + // convert inputs/outpus NDArray to C types to be passed to lib_api.h + for (size_t i = 0; i < inputs.size(); i++) { + in_data.push_back(inputs[i].data().dptr_); + in_shapes.push_back(inputs[i].shape().data()); + in_dims.push_back(inputs[i].shape().ndim()); + in_types.push_back(inputs[i].dtype()); + in_verIDs.push_back(inputs[i].version()); + const char* ctx_str = inputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu"; + in_dev_type.push_back(ctx_str); + in_dev_id.push_back(inputs[i].ctx().real_dev_id()); + } + + for (size_t i = 0; i < outputs.size(); i++) { + out_data.push_back(outputs[i].data().dptr_); + out_shapes.push_back(outputs[i].shape().data()); + out_dims.push_back(outputs[i].shape().ndim()); + out_types.push_back(outputs[i].dtype()); + out_verIDs.push_back(outputs[i].version()); + const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? 
"cpu" : "gpu"; + out_dev_type.push_back(ctx_str); + out_dev_id.push_back(outputs[i].ctx().real_dev_id()); + } + + // get memory resource and mxnet backend streams + const Resource &resource = ctx.requested[0]; + mshadow::Stream *cpu_stream = ctx.get_stream(); + mshadow::Stream *gpu_stream = ctx.get_stream(); + + // create lambda that captures stream & resource objects + // this temp workspace holds memory allocated by custom library via OpResource + auto cpu_alloc = [&](int size) { + mshadow::Tensor workspace = + resource.get_space_typed(mshadow::Shape1(size), cpu_stream); + return workspace.dptr_; + }; + auto gpu_alloc = [&](int size) { + mshadow::Tensor workspace = + resource.get_space_typed(mshadow::Shape1(size), gpu_stream); + return workspace.dptr_; + }; + + // create lambda without captures so that we can cast it to function pointer + // lambda with captures cannot be cast to function pointer and pass to lib_api.h + // this needs to be a lambda function so that we can do the decltype cast + typedef decltype(cpu_alloc) alloc_type_cpu; + auto cpu_malloc = [](void* _cpu_alloc, int size) { + // cast the void* argument to the type for the cpu_alloc lambda function + alloc_type_cpu* cpualloc = static_cast(_cpu_alloc); + // call cpu_alloc to actually allocate memory and return the pointer + return static_cast((*cpualloc)(size)); + }; + typedef decltype(gpu_alloc) alloc_type_gpu; + auto gpu_malloc = [](void* _gpu_alloc, int size) { + alloc_type_gpu* gpualloc = static_cast(_gpu_alloc); + return static_cast((*gpualloc)(size)); + }; + + // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h + void *cuda_stream = nullptr; +#if MXNET_USE_CUDA + if (inputs[0].ctx().dev_mask() == Context::kGPU) { + cuda_stream = static_cast(gpu_stream->stream_); + } +#endif + + CHECK((fcomp_fp != nullptr && state_ptr == nullptr) + || (fcomp_fp == nullptr && state_ptr != nullptr)) + << "Can only register either regular op or stateful op for '" << op_name << "'"; + + if (fcomp_fp != nullptr) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto kv : attrs->dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + // call fcompute function + CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), + in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), + in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), + out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(), + cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream)) + << "Error calling FCompute for custom operator '" << op_name << "'"; + } + + if (state_ptr != nullptr) { + // retrieve op state object created from CreateOpState + CustomStatefulOpWrapper& op = state_ptr->get_state(); + CustomStatefulOp* state_op_inst = op.get_instance(); + CHECK(state_op_inst != nullptr) + << "Error custom stateful operator is null for operator '" << op_name << "'"; + + // call fcompute function + CHECK(callFStatefulComp(stateful_forward_flag, state_op_inst, + in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(), + in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), + in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), + out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), + out_data.size(), + cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream)) + << "Error calling 
FStatefulCompute for custom operator '" << op_name << "'"; + } +} + +/*! + * \brief Loads dynamic custom library and initializes it * \param path library path */ int MXLoadLib(const char *path) { @@ -164,39 +293,60 @@ int MXLoadLib(const char *path) { for (int i = 0; i < numOps; i++) { const char* name; // function pointers holding implementation from custom library - fcomp_t fcomp_fp = nullptr; parseAttrs_t parse_fp = nullptr; inferType_t type_fp = nullptr; inferShape_t shape_fp = nullptr; // optional attributes - fcomp_t fgrad_fp = nullptr; mutateInputs_t mutate_fp = nullptr; - createOpState_t create_opstate_fp = nullptr; bool isSubgraphOp = false; int _isSubgraphOp = 0; - - // get custom operator implemenation from the dynamic library - opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, - &mutate_fp, &create_opstate_fp, &_isSubgraphOp); + // lists of forward and backward function associated with each context + const char **forward_ctx, **backward_ctx, **createop_ctx; + fcomp_t *forward_fcomp, *backward_fcomp; + createOpState_t *createop_fp; + int forward_count, backward_count, createop_count; + + // main function to get custom operator implemenation from the custom library + opRegGet(i, &name, &_isSubgraphOp, + &forward_ctx, &forward_fcomp, &forward_count, + &backward_ctx, &backward_fcomp, &backward_count, + &createop_ctx, &createop_fp, &createop_count, + &parse_fp, &type_fp, &shape_fp, &mutate_fp); + + // construct maps of context to forward/backward custom library function + std::unordered_map forward_ctx_map; + std::unordered_map backward_ctx_map; + std::unordered_map createop_map; + for (int i=0; i < forward_count; i++) { + std::string ctx_str(forward_ctx[i]); + forward_ctx_map[ctx_str] = forward_fcomp[i]; + } + for (int i=0; i < backward_count; i++) { + std::string ctx_str(backward_ctx[i]); + backward_ctx_map[ctx_str] = backward_fcomp[i]; + } + for (int i=0; i < createop_count; i++) { + std::string ctx_str(createop_ctx[i]); + createop_map[ctx_str] = createop_fp[i]; + } // set bool, dont pass bool across ABI boundary isSubgraphOp = _isSubgraphOp; + // validate custom operator functions from the dynamic library if (!isSubgraphOp) { - // validate custom operator functions from the dynamic library CHECK(parse_fp != nullptr) << "Error loading '" << name << "' custom op, ParseAttrs function was not set."; - CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name + CHECK(forward_ctx_map.size() != 0 || createop_map.size() != 0) + << "Error loading '" << name << "' custom op, Forward or CreateOpState function was not set."; - CHECK(type_fp != nullptr) << "Error loading '" << name + CHECK(type_fp != nullptr) << "Error loading '" << name << "' custom op, InferType function was not set."; CHECK(shape_fp != nullptr) << "Error loading '" << name << "' custom op, InferShape function was not set."; } else { - // validate custom operator functions from the dynamic library - CHECK(create_opstate_fp != nullptr) << "Error loading '" << name + CHECK(createop_map.size() != 0) << "Error loading '" << name << "' custom subgraph op, CreateOpState function was not set."; } - LOG(INFO) << "\tOp[" << i << "] " << name; std::string name_str(name); @@ -285,7 +435,7 @@ int MXLoadLib(const char *path) { &num_in, &num_out)) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str << "'"; - return num_in + 2*num_out; + return num_in + 2 * num_out; }; // lambda function to call infer shape @@ -389,95 +539,6 @@ int MXLoadLib(const char *path) { 
return true; }; - // lambda function to convert from external fcompute to internal MXNet types - auto fcomp_lambda = [=](fcomp_t fcomp_fp, - const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - // convert attributes to vector of char* - std::vector attr_keys, attr_vals; - for (auto kv : attrs.dict) { - attr_keys.push_back(kv.first.c_str()); - attr_vals.push_back(kv.second.c_str()); - } - - std::vector in_data, out_data; - std::vector in_shapes, out_shapes; - std::vector in_dims, out_dims; - std::vector in_types, out_types; - std::vector in_verIDs, out_verIDs; - - // convert input tensors to constituent parts - for (size_t i = 0; i < inputs.size(); i++) { - in_data.push_back(inputs[i].data().dptr_); - in_shapes.push_back(inputs[i].shape().data()); - in_dims.push_back(inputs[i].shape().ndim()); - in_types.push_back(inputs[i].dtype()); - in_verIDs.push_back(inputs[i].version()); - } - - // convert output tensors to constituent parts - for (size_t i = 0; i < outputs.size(); i++) { - out_data.push_back(outputs[i].data().dptr_); - out_shapes.push_back(outputs[i].shape().data()); - out_dims.push_back(outputs[i].shape().ndim()); - out_types.push_back(outputs[i].dtype()); - out_verIDs.push_back(outputs[i].version()); - } - - // get memory resource - const Resource &resource = ctx.requested[0]; - mshadow::Stream *cpu_stream = ctx.get_stream(); - - // create lambda that captures stream & resource objects - // this temp workspace holds memory allocated by custom library via OpResource - auto cpu_alloc = [&](int size) { - mshadow::Tensor workspace = - resource.get_space_typed(mshadow::Shape1(size), cpu_stream); - return workspace.dptr_; - }; - - // create lambda without captures so that we can cast it to function pointer - // this needs to be a lambda function so that we can do the decltype cast - typedef decltype(cpu_alloc) alloc_type; - auto cpu_malloc = [](void* _cpu_alloc, int size) { - // cast the void* argument to the type for the cpu_alloc lambda function - alloc_type* cpualloc = static_cast(_cpu_alloc); - // call cpu_alloc to actually allocate memory and get the pointer - void* ptr = (*cpualloc)(size); - return ptr; - }; - - // call fcompute function - CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - in_shapes.data(), in_dims.data(), in_data.data(), - in_types.data(), in_verIDs.data(), in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), - out_types.data(), out_verIDs.data(), out_data.size(), - cpu_malloc, &cpu_alloc)) - << "Error calling FCompute for custom operator '" << name_str << "'"; - - // return type void - }; - - auto forward_lambda = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - return fcomp_lambda(fcomp_fp, attrs, ctx, inputs, req, outputs); - }; - - auto backward_lambda = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - return fcomp_lambda(fgrad_fp, attrs, ctx, inputs, req, outputs); - }; - // lambda function to convert from external mutate_inputs to internal MXNet types auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) { // convert attributes to vector of char* @@ -542,9 +603,9 @@ int MXLoadLib(const char *path) { // library author should implement and return a 'state' which points to an instance // in lambda we create OpStatePtr using the 
returned 'state' auto create_opstate = [=] (const NodeAttrs& attrs, - Context ctx, - const std::vector& in_shapes, - const std::vector& in_types) { + Context ctx, + const std::vector& in_shapes, + const std::vector& in_types) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; for (auto kv : attrs.dict) { @@ -563,101 +624,31 @@ int MXLoadLib(const char *path) { } // create a pointer to hold custom op state object + // only create one stateful op depending on passing context + // user can add new supported context and call to custom library void* state_op_inst = nullptr; - CHECK(callCreateOpState(create_opstate_fp, attr_keys.data(), attr_vals.data(), - attr_keys.size(), &state_op_inst)) - << "Error calling CreateOpState for custom operator '" << name_str << "'"; + if (ctx.dev_mask() == Context::kCPU) { + CHECK(createop_map.count("cpu") > 0) + << "CPU CreateOpState not implemented for '" << name_str << "'"; + CHECK(callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(), + attr_keys.size(), &state_op_inst)) + << "Error calling CreateOpState CPU for custom operator '" << name_str << "'"; + } else if (ctx.dev_mask() == Context::kGPU) { + CHECK(createop_map.count("gpu") > 0) + << "GPU CreateOpState not implemented for '" << name_str << "'"; + CHECK(callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(), + attr_keys.size(), &state_op_inst)) + << "Error calling CreateOpState GPU for custom operator '" << name_str << "'"; + } CHECK(state_op_inst != nullptr) - << "Error custom library failed to create stateful operator '" << name_str << "'"; + << "Error custom library failed to create stateful operator '" << name_str << "'"; CustomStatefulOp* state_op = reinterpret_cast(state_op_inst); return OpStatePtr::Create(state_op); }; - // stateful forward and backward - auto fstateful_lambda = [=](bool is_forward, - const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector in_data, out_data; - std::vector in_shapes, out_shapes; - std::vector in_dims, out_dims; - std::vector in_types, out_types; - std::vector in_verIDs, out_verIDs; - - // convert input tensors to constituent parts - for (size_t i = 0; i < inputs.size(); i++) { - in_data.push_back(inputs[i].data().dptr_); - in_shapes.push_back(inputs[i].shape().data()); - in_dims.push_back(inputs[i].shape().ndim()); - in_types.push_back(inputs[i].dtype()); - in_verIDs.push_back(inputs[i].version()); - } - - // convert output tensors to constituent parts - for (size_t i = 0; i < outputs.size(); i++) { - out_data.push_back(outputs[i].data().dptr_); - out_shapes.push_back(outputs[i].shape().data()); - out_dims.push_back(outputs[i].shape().ndim()); - out_types.push_back(outputs[i].dtype()); - out_verIDs.push_back(outputs[i].version()); - } - - // get memory resource - const Resource &resource = ctx.requested[0]; - mshadow::Stream *cpu_stream = ctx.get_stream(); - - // create lambda that captures stream & resource objects - // this temp workspace holds memory allocated by custom library via OpResource - auto cpu_alloc = [&](int size) { - mshadow::Tensor data = - resource.get_space_typed(mshadow::Shape1(size), cpu_stream); - return data.dptr_; - }; - - // create lambda without captures so that we can cast it to function pointer - // this needs to be a lambda function so that we can do the decltype cast - typedef decltype(cpu_alloc) alloc_type; - auto cpu_malloc = [](void* _cpu_alloc, int size) { - // 
-      // cast the void* argument to the type for the cpu_alloc lambda function
-      alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
-      // call cpu_alloc to actually allocate memory and get the pointer
-      void* ptr = (*cpualloc)(size);
-      return ptr;
-    };
-
-    // retrieve op state object created from CreateOpState
-    CustomStatefulOpWrapper& op = state_ptr.get_state<CustomStatefulOpWrapper>();
-    CustomStatefulOp* state_op_inst = op.get_instance();
-    CHECK(state_op_inst != nullptr)
-      << "Error MXNet cannot load custom stateful operator '" << name_str << "'";
-
-    // call fcompute function
-    CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(),
-                            in_data.data(), in_types.data(), in_verIDs.data(), in_data.size(),
-                            out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
-                            out_verIDs.data(), out_data.size(), cpu_malloc, &cpu_alloc))
-      << "Error calling FStatefulCompute for custom operator '" << name_str << "'";
-  };
-
-  auto fstateful_forward = [=](const OpStatePtr& state_ptr,
-                               const OpContext& ctx,
-                               const std::vector<NDArray>& inputs,
-                               const std::vector<OpReqType>& req,
-                               const std::vector<NDArray>& outputs) {
-    fstateful_lambda(true, state_ptr, ctx, inputs, req, outputs);
-  };
-
-  auto fstateful_backward = [=](const OpStatePtr& state_ptr,
-                                const OpContext& ctx,
-                                const std::vector<NDArray>& inputs,
-                                const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
-    fstateful_lambda(false, state_ptr, ctx, inputs, req, outputs);
-  };
+  /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */
 
   // check if operator is already registered
   const nnvm::Op *regOpPtr = dmlc::Registry<nnvm::Op>::Get()->Find(name);
@@ -685,10 +676,8 @@ int MXLoadLib(const char *path) {
     using namespace mxnet::op;
     regOp.set_num_inputs(DefaultSubgraphOpNumInputs);
     regOp.set_num_outputs(DefaultSubgraphOpNumOutputs);
-    regOp.set_attr<nnvm::FInferType>("FInferType",
-                                     DefaultSubgraphOpType, plevel);
-    regOp.set_attr<mxnet::FInferShape>("FInferShape",
-                                       DefaultSubgraphOpShape, plevel);
+    regOp.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType, plevel);
+    regOp.set_attr<mxnet::FInferShape>("FInferShape", DefaultSubgraphOpShape, plevel);
     regOp.set_attr<FInferStorageType>("FInferStorageType",
                                       DefaultSubgraphOpStorageType, plevel);
     regOp.set_attr<FResourceRequest>("FResourceRequest",
@@ -696,17 +685,47 @@ int MXLoadLib(const char *path) {
                                      DefaultSubgraphOpResourceRequest, plevel);
     regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
                                         DefaultSubgraphOpMutableInputs, plevel);
   }
-  // optionally add stateful forward
-  if (create_opstate_fp != nullptr) {
+  if (createop_map.size() != 0) {
     regOp.set_attr<FCreateOpState>("FCreateOpState", create_opstate, plevel);
-    regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
-                                       fstateful_forward, plevel);
+    auto fstate_forward = [=](const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+      CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
+                               callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs);
+    };
+    regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_forward, plevel);
+    regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_forward, plevel);
   } else {
-    regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_lambda, plevel);
+    if (forward_ctx_map.count("cpu") > 0) {
+      fcomp_t fcomp_cpu = forward_ctx_map.at("cpu");
+      auto forward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+        CustomFComputeDispatcher(name_str, callFComp, fcomp_cpu, &attrs,
+                                 nullptr, 0, nullptr, ctx, inputs, req, outputs);
+      };
+      regOp.set_attr<FComputeEx>("FComputeEx<cpu>", forward_cpu_lambda, plevel);
+    }
+    if (forward_ctx_map.count("gpu") > 0) {
+      fcomp_t fcomp_gpu = forward_ctx_map.at("gpu");
+      auto forward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+        CustomFComputeDispatcher(name_str, callFComp, fcomp_gpu, &attrs,
+                                 nullptr, 0, nullptr, ctx, inputs, req, outputs);
+      };
+      regOp.set_attr<FComputeEx>("FComputeEx<gpu>", forward_gpu_lambda, plevel);
+    }
   }
   // optionally add fgradient if user specified a function
-  if (fgrad_fp != nullptr || create_opstate_fp != nullptr) {
+  if (backward_ctx_map.size() != 0 || createop_map.size() != 0) {
     regOp.set_attr<nnvm::FGradient>("FGradient", grad_reg, plevel);
     std::string grad_name = "_backward_" + name_str;
     nnvm::Op &gradOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(grad_name);
@@ -716,12 +735,44 @@ int MXLoadLib(const char *path) {
     gradOp.set_num_outputs(num_inputs);
     gradOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
     gradOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
-    if (create_opstate_fp != nullptr) {
+    if (createop_map.size() != 0) {
       gradOp.set_attr<bool>("TIsLayerOpBackward", true, plevel);
-      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>",
-                                          fstateful_backward, plevel);
+      auto fstate_backward = [=](const OpStatePtr& state_ptr,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+        CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
+                                 callFStatefulComp, 0, &state_ptr,
+                                 ctx, inputs, req, outputs);
+      };
+      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
+      gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
     } else {
-      gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_lambda, plevel);
+      if (backward_ctx_map.count("cpu") > 0) {
+        fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu");
+        auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+          CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs,
+                                   nullptr, 0, nullptr, ctx, inputs, req, outputs);
+        };
+        gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_cpu_lambda, plevel);
+      }
+      if (backward_ctx_map.count("gpu") > 0) {
+        fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu");
+        auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+          CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs,
+                                   nullptr, 0, nullptr, ctx, inputs, req, outputs);
+        };
+        gradOp.set_attr<FComputeEx>("FComputeEx<gpu>", backward_gpu_lambda, plevel);
+      }
     }
   }
   regOp.add_argument("data", "NDArray[]", "Source inputs");
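The cpu_alloc/cpu_malloc pairs above (and CustomFComputeDispatcher, which now centralizes them) all lean on the same C++ idiom: only a capture-less lambda can decay to a plain function pointer, so the capturing allocator is passed across the C ABI as an opaque void* and recovered with a decltype-derived cast. Below is a minimal self-contained sketch of that idiom; the names xpu_malloc_t and call_through_c_abi are illustrative only, not part of the MXNet API:

    #include <cstdlib>

    // illustrative stand-in for the C-ABI callback pair used above:
    // a plain function pointer plus an opaque void* context
    typedef void* (*xpu_malloc_t)(void* ctx, int size);

    void* call_through_c_abi(xpu_malloc_t fp, void* ctx) {
      return fp(ctx, 64);  // the C side only sees a function pointer + void*
    }

    int main() {
      int calls = 0;
      // capturing lambda: cannot decay to a function pointer
      auto alloc = [&](int size) { ++calls; return std::malloc(size); };
      // capture-less lambda CAN decay; it recovers the capturing lambda
      // from the void* context using the decltype-derived type
      typedef decltype(alloc) alloc_type;
      xpu_malloc_t thunk = [](void* ctx, int size) {
        return (*static_cast<alloc_type*>(ctx))(size);
      };
      void* p = call_through_c_abi(thunk, &alloc);
      std::free(p);
      return (p != nullptr && calls == 1) ? 0 : 1;
    }
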
diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py
index 63b54d8516b4..39ad9d03d470 100644
--- a/tests/python/unittest/test_extensions.py
+++ b/tests/python/unittest/test_extensions.py
@@ -23,7 +23,7 @@
 import mxnet as mx
 import numpy as np
 from mxnet.base import MXNetError
-from mxnet.test_utils import download, is_cd_run, assert_almost_equal
+from mxnet.test_utils import download, is_cd_run, assert_almost_equal, default_context
 
 def check_platform():
     return platform.machine() not in ['x86_64', 'AMD64']
@@ -31,8 +31,9 @@ def check_platform():
 @unittest.skipIf(check_platform(), "not all machine types supported")
 @unittest.skipIf(is_cd_run(), "continuous delivery run - ignoring test")
 def test_custom_op():
+    # possible places to find library file
     if (os.name=='posix'):
-        lib = 'libsample_lib.so'
+        lib = 'libcustomop_lib.so'
         if os.path.exists(lib):
             fname = lib
         elif os.path.exists('build/'+lib):
@@ -40,27 +41,30 @@ def test_custom_op():
         else:
             raise MXNetError("library %s not found " % lib)
     elif (os.name=='nt'):
-        lib = 'libsample_lib.dll'
+        lib = 'libcustomop_lib.dll'
         if os.path.exists('windows_package\\lib\\'+lib):
             fname = 'windows_package\\lib\\'+lib
         else:
             raise MXNetError("library %s not found " % lib)
 
     fname = os.path.abspath(fname)
+    # load the library containing gemm custom operators
     mx.library.load(fname)
 
-    # test simple 2D gemm custom op loaded from sample library
+    # test symbol 2D gemm custom operators
     s = mx.sym.Variable('s')
     t = mx.sym.Variable('t')
     c = mx.sym.my_gemm(s,t)
     d = mx.sym.state_gemm(s,t)
-    base = mx.sym.linalg.gemm2(s,t) # baseline
+    # baseline gemm from MXNet
+    base = mx.sym.linalg.gemm2(s,t)
 
+    # get some random input matrices
     dim_n, dim_k, dim_m = tuple(np.random.randint(1, 5, size=3))
-
     mat1 = mx.nd.random.uniform(-10, 10, shape=(dim_n, dim_k), ctx=mx.cpu())
     mat2 = mx.nd.random.uniform(-10, 10, shape=(dim_k, dim_m), ctx=mx.cpu())
 
+    # intermediate ndarrays to be populated by gradient compute
     in_grad1 = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())]
     in_grad2 = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())]
     in_grad_base = [mx.nd.empty((dim_n,dim_k),ctx=mx.cpu()),mx.nd.empty((dim_k,dim_m),ctx=mx.cpu())]
@@ -71,17 +75,21 @@ def test_custom_op():
     out1 = exe1.forward()
     out2 = exe2.forward()
-    out2 = exe2.forward() # stateful
+    # test stateful operator by calling it multiple times
+    out2 = exe2.forward()
     out_base = exe_base.forward()
 
+    # check that forward compute matches one executed by MXNet
     assert_almost_equal(out_base[0].asnumpy(), out1[0].asnumpy(), rtol=1e-3, atol=1e-3)
     assert_almost_equal(out_base[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3)
 
+    # output grad ndarray of all ones for the gradient update
     out_grad = mx.nd.ones((dim_n, dim_m), ctx=mx.cpu())
     exe1.backward([out_grad])
     exe2.backward([out_grad])
     exe_base.backward([out_grad])
 
+    # check that gradient compute matches one executed by MXNet
     assert_almost_equal(in_grad_base[0].asnumpy(), in_grad1[0].asnumpy(), rtol=1e-3, atol=1e-3)
     assert_almost_equal(in_grad_base[0].asnumpy(), in_grad2[0].asnumpy(), rtol=1e-3, atol=1e-3)
 
@@ -148,3 +156,47 @@ def test_subgraph():
     out3 = exe3.forward()
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out3[0].asnumpy(), rtol=1e-3, atol=1e-3)
+
+@unittest.skipIf(check_platform(), "not all machine types supported")
+@unittest.skipIf(is_cd_run(), "continuous delivery run - ignoring test")
+@unittest.skipIf(default_context().device_type == 'cpu', "ignoring custom_op_gpu test on cpu run")
+def test_custom_op_gpu():
+    # possible places to find library file
+    if (os.name=='posix'):
+        lib = 'libcustomop_gpu_lib.so'
+        if os.path.exists(lib):
+            fname = lib
+        elif os.path.exists('build/'+lib):
+            fname = 'build/'+lib
+        else:
+            raise MXNetError("library %s not found " % lib)
+    elif (os.name=='nt'):
+        lib = 'libcustomop_gpu_lib.dll'
+        if os.path.exists('windows_package\\lib\\'+lib):
+            fname = 'windows_package\\lib\\'+lib
+        else:
+            raise MXNetError("library %s not found " % lib)
+
+    fname = os.path.abspath(fname)
+    # load the library containing the custom relu operator
+    mx.library.load(fname)
+
+    # test symbol custom relu operator on gpu
+    b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())
+    c = mx.sym.Variable('c')
+    d = mx.sym.Variable('d')
+    e = mx.sym.my_relu(c)
+    base = mx.sym.relu(d)
+    in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())]
+    in_grad_base = [mx.nd.empty((2,2), ctx=mx.gpu())]
+    exe = e.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad)
+    exe_base = base.bind(ctx=mx.gpu(), args={'d':b}, args_grad=in_grad_base)
+    out = exe.forward()
+    out_base = exe_base.forward()
+    assert_almost_equal(out_base[0].asnumpy(), out[0].asnumpy(), rtol=1e-3, atol=1e-3)
+
+    # test backward
+    out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
+    exe.backward([out_grad])
+    exe_base.backward([out_grad])
+    assert_almost_equal(in_grad_base[0].asnumpy(), in_grad[0].asnumpy(), rtol=1e-3, atol=1e-3)
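
Once built and loaded, the operators exercised by these tests are also callable through the imperative NDArray API, not just the symbolic one. A short usage sketch, assuming a CUDA-enabled build that produced build/libcustomop_gpu_lib.so as probed by test_custom_op_gpu above:

    import mxnet as mx

    # path assumed to match the build layout probed by the tests above
    mx.library.load('build/libcustomop_gpu_lib.so')

    # my_relu comes from relu_lib.cu and is registered when the library loads
    a = mx.nd.array([[-2, -1], [1, 2]], ctx=mx.gpu())
    b = mx.nd.my_relu(a)
    print(b.asnumpy())  # negative entries are zeroed, matching mx.nd.relu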