diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc index b5fc8d15f7ac..d26704c35cf5 100644 --- a/src/executor/tensorrt_pass.cc +++ b/src/executor/tensorrt_pass.cc @@ -324,10 +324,10 @@ nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g, std::unordered_map* const params_map) { auto p = nnvm::Node::Create(); p->attrs.op = nnvm::Op::Get("_trt_op"); - op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map); - p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map; - p->attrs.dict["serialized_input_map"] = trt_param.serialized_input_map; - p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph; + op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map); + p->attrs.dict["serialized_output_map"] = onnx_param.serialized_output_map; + p->attrs.dict["serialized_input_map"] = onnx_param.serialized_input_map; + p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph; if (p->op()->attr_parser != nullptr) { p->op()->attr_parser(&(p->attrs)); } diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h index 58f88b051433..011ffe6b7ddb 100644 --- a/src/operator/contrib/nnvm_to_onnx-inl.h +++ b/src/operator/contrib/nnvm_to_onnx-inl.h @@ -37,7 +37,6 @@ #include #include -#include #include #include @@ -49,13 +48,48 @@ #include #include -#include "./tensorrt-inl.h" #include "../operator_common.h" #include "../../common/utils.h" #include "../../common/serialization.h" namespace mxnet { namespace op { + +namespace nnvm_to_onnx { + enum class TypeIO { Inputs = 0, Outputs = 1 }; + using NameToIdx_t = std::map; + using InferenceTuple_t = std::tuple; + using InferenceMap_t = std::map; +} // namespace nnvm_to_onnx + +struct ONNXParam : public dmlc::Parameter { + std::string serialized_onnx_graph; + std::string serialized_input_map; + std::string serialized_output_map; + nnvm_to_onnx::NameToIdx_t input_map; + nnvm_to_onnx::InferenceMap_t output_map; + ::onnx::ModelProto onnx_pb_graph; + + ONNXParam() {} + + ONNXParam(const ::onnx::ModelProto& onnx_graph, + const nnvm_to_onnx::InferenceMap_t& input_map, + const nnvm_to_onnx::NameToIdx_t& output_map) { + common::Serialize(input_map, &serialized_input_map); + common::Serialize(output_map, &serialized_output_map); + onnx_graph.SerializeToString(&serialized_onnx_graph); + } + +DMLC_DECLARE_PARAMETER(ONNXParam) { + DMLC_DECLARE_FIELD(serialized_onnx_graph) + .describe("Serialized ONNX graph"); + DMLC_DECLARE_FIELD(serialized_input_map) + .describe("Map from inputs to topological order as input."); + DMLC_DECLARE_FIELD(serialized_output_map) + .describe("Map from outputs to order in g.outputs."); + } +}; + namespace nnvm_to_onnx { using namespace nnvm; @@ -76,7 +110,7 @@ void ConvertConstant(GraphProto* const graph_proto, const std::string& node_name, std::unordered_map* const shared_buffer); -void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map, +void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map, GraphProto* const graph_proto, const std::unordered_map::iterator& out_iter, const std::string& node_name, @@ -133,7 +167,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto, const nnvm::IndexedGraph &ig, const array_view &inputs); -TRTParam ConvertNnvmGraphToOnnx( +ONNXParam ConvertNnvmGraphToOnnx( const nnvm::Graph &g, std::unordered_map *const shared_buffer); diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc index 902466614c7c..784384e94e1e 100644 --- a/src/operator/contrib/nnvm_to_onnx.cc +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -47,7 +47,6 @@ #include "../../operator/nn/fully_connected-inl.h" #include "../../operator/nn/pooling-inl.h" #include "../../operator/softmax_output-inl.h" -#include "./tensorrt-inl.h" #if MXNET_USE_TENSORRT_ONNX_CHECKER #include @@ -55,14 +54,17 @@ namespace mxnet { namespace op { + +DMLC_REGISTER_PARAMETER(ONNXParam); + namespace nnvm_to_onnx { -op::TRTParam ConvertNnvmGraphToOnnx( +op::ONNXParam ConvertNnvmGraphToOnnx( const nnvm::Graph& g, std::unordered_map* const shared_buffer) { - op::TRTParam trt_param; - op::tensorrt::NameToIdx_t trt_input_map; - op::tensorrt::InferenceMap_t trt_output_map; + op::ONNXParam onnx_param; + op::nnvm_to_onnx::NameToIdx_t onnx_input_map; + op::nnvm_to_onnx::InferenceMap_t onnx_output_map; const nnvm::IndexedGraph& ig = g.indexed_graph(); const auto& storage_types = g.GetAttr("storage_type"); @@ -105,7 +107,7 @@ op::TRTParam ConvertNnvmGraphToOnnx( current_input++; continue; } - trt_input_map.emplace(node_name, current_input++); + onnx_input_map.emplace(node_name, current_input++); ConvertPlaceholder(node_name, placeholder_shapes, graph_proto); } else { // If it's not a placeholder, then by exclusion it's a constant. @@ -140,23 +142,23 @@ op::TRTParam ConvertNnvmGraphToOnnx( auto out_iter = output_lookup.find(node_name); // We found an output if (out_iter != output_lookup.end()) { - ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g, + ConvertOutput(&onnx_output_map, graph_proto, out_iter, node_name, g, storage_types, dtypes); } // output found } // conversion function exists } // loop over i from 0 to num_nodes - model_proto.SerializeToString(&trt_param.serialized_onnx_graph); - common::Serialize(trt_input_map, - &trt_param.serialized_input_map); - common::Serialize(trt_output_map, - &trt_param.serialized_output_map); + model_proto.SerializeToString(&onnx_param.serialized_onnx_graph); + common::Serialize(onnx_input_map, + &onnx_param.serialized_input_map); + common::Serialize(onnx_output_map, + &onnx_param.serialized_output_map); #if MXNET_USE_TENSORRT_ONNX_CHECKER onnx::checker::check_model(model_proto); #endif // MXNET_USE_TENSORRT_ONNX_CHECKER - return trt_param; + return onnx_param; } void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, @@ -489,7 +491,7 @@ void ConvertConstant( } void ConvertOutput( - op::tensorrt::InferenceMap_t* const trt_output_map, + op::nnvm_to_onnx::InferenceMap_t* const output_map, GraphProto* const graph_proto, const std::unordered_map::iterator& out_iter, const std::string& node_name, const nnvm::Graph& g, @@ -501,10 +503,10 @@ void ConvertOutput( int dtype = dtypes[out_idx]; // This should work with fp16 as well - op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type, + op::nnvm_to_onnx::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type, dtype}; - trt_output_map->emplace(node_name, out_tuple); + output_map->emplace(node_name, out_tuple); auto graph_out = graph_proto->add_output(); auto tensor_type = graph_out->mutable_type()->mutable_tensor_type(); diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h index be335ab1208f..062d22e35795 100644 --- a/src/operator/contrib/tensorrt-inl.h +++ b/src/operator/contrib/tensorrt-inl.h @@ -38,7 +38,6 @@ #include #include -#include #include #include @@ -49,6 +48,7 @@ #include #include +#include "nnvm_to_onnx-inl.h" #include "../operator_common.h" #include "../../common/utils.h" #include "../../common/serialization.h" @@ -60,49 +60,15 @@ namespace mxnet { namespace op { using namespace nnvm; -using namespace ::onnx; using int64 = ::google::protobuf::int64; -namespace tensorrt { - enum class TypeIO { Inputs = 0, Outputs = 1 }; - using NameToIdx_t = std::map; - using InferenceTuple_t = std::tuple; - using InferenceMap_t = std::map; -} // namespace tensorrt using trt_name_to_idx = std::map; -struct TRTParam : public dmlc::Parameter { - std::string serialized_onnx_graph; - std::string serialized_input_map; - std::string serialized_output_map; - tensorrt::NameToIdx_t input_map; - tensorrt::InferenceMap_t output_map; - ::onnx::ModelProto onnx_pb_graph; - - TRTParam() {} - - TRTParam(const ::onnx::ModelProto& onnx_graph, - const tensorrt::InferenceMap_t& input_map, - const tensorrt::NameToIdx_t& output_map) { - common::Serialize(input_map, &serialized_input_map); - common::Serialize(output_map, &serialized_output_map); - onnx_graph.SerializeToString(&serialized_onnx_graph); - } - -DMLC_DECLARE_PARAMETER(TRTParam) { - DMLC_DECLARE_FIELD(serialized_onnx_graph) - .describe("Serialized ONNX graph"); - DMLC_DECLARE_FIELD(serialized_input_map) - .describe("Map from inputs to topological order as input."); - DMLC_DECLARE_FIELD(serialized_output_map) - .describe("Map from outputs to order in g.outputs."); - } -}; struct TRTEngineParam { nvinfer1::IExecutionContext* trt_executor; - std::vector > binding_map; + std::vector > binding_map; }; } // namespace op diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc index 619fe1e2b8f4..88a65fba3ea3 100644 --- a/src/operator/contrib/tensorrt.cc +++ b/src/operator/contrib/tensorrt.cc @@ -44,20 +44,18 @@ namespace mxnet { namespace op { -DMLC_REGISTER_PARAMETER(TRTParam); - OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine, - tensorrt::NameToIdx_t input_map, - tensorrt::NameToIdx_t output_map) { + nnvm_to_onnx::NameToIdx_t input_map, + nnvm_to_onnx::NameToIdx_t output_map) { TRTEngineParam param; for (int b = 0; b < trt_engine->getNbBindings(); ++b) { const std::string& binding_name = trt_engine->getBindingName(b); if (trt_engine->bindingIsInput(b)) { param.binding_map.emplace_back(input_map[binding_name], - tensorrt::TypeIO::Inputs); + nnvm_to_onnx::TypeIO::Inputs); } else { param.binding_map.emplace_back(output_map[binding_name], - tensorrt::TypeIO::Outputs); + nnvm_to_onnx::TypeIO::Outputs); } } param.trt_executor = trt_engine->createExecutionContext(); @@ -67,7 +65,7 @@ OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine, OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/, const std::vector& /*ishape*/, const std::vector& /*itype*/) { - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); ::onnx::ModelProto model_proto; bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph); @@ -82,7 +80,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/, nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx( node_param.serialized_onnx_graph, batch_size, 1 << 30); - tensorrt::NameToIdx_t output_map; + nnvm_to_onnx::NameToIdx_t output_map; for (auto& el : node_param.output_map) { output_map[el.first] = std::get<0>(el.second); } @@ -90,7 +88,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/, } void TRTParamParser(nnvm::NodeAttrs* attrs) { - TRTParam param_; + ONNXParam param_; try { param_.Init(attrs->dict); @@ -114,7 +112,7 @@ void TRTParamParser(nnvm::NodeAttrs* attrs) { inline bool TRTInferShape(const NodeAttrs& attrs, std::vector* /*in_shape*/, std::vector* out_shape) { - const auto &node_param = nnvm::get(attrs.parsed); + const auto &node_param = nnvm::get(attrs.parsed); for (auto& el : node_param.output_map) { (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second); } @@ -131,7 +129,7 @@ inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask inline bool TRTInferType(const NodeAttrs& attrs, std::vector* /*in_dtype*/, std::vector* out_dtype) { - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); for (auto& el : node_param.output_map) { (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second); } @@ -140,7 +138,7 @@ inline bool TRTInferType(const NodeAttrs& attrs, std::vector* /*in_dtype*/, inline std::vector TRTListInputNames(const NodeAttrs& attrs) { std::vector output; - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); output.resize(node_param.input_map.size()); for (auto& el : node_param.input_map) { output[el.second] = el.first; @@ -150,7 +148,7 @@ inline std::vector TRTListInputNames(const NodeAttrs& attrs) { inline std::vector TRTListOutputNames(const NodeAttrs& attrs) { std::vector output; - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); output.resize(node_param.output_map.size()); for (auto& el : node_param.output_map) { output[std::get<0>(el.second)] = el.first; @@ -162,11 +160,11 @@ NNVM_REGISTER_OP(_trt_op) .describe(R"code(TRT operation (one engine) )code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); return node_param.input_map.size(); }) .set_num_outputs([](const NodeAttrs& attrs) { - const auto& node_param = nnvm::get(attrs.parsed); + const auto& node_param = nnvm::get(attrs.parsed); return node_param.output_map.size(); }) .set_attr_parser(TRTParamParser) diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu index 2fe8727b73e4..9a9c3c024366 100644 --- a/src/operator/contrib/tensorrt.cu +++ b/src/operator/contrib/tensorrt.cu @@ -52,7 +52,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx, std::vector bindings; bindings.reserve(param.binding_map.size()); for (auto& p : param.binding_map) { - if (p.second == tensorrt::TypeIO::Inputs) { + if (p.second == nnvm_to_onnx::TypeIO::Inputs) { bindings.emplace_back(inputs[p.first].dptr_); } else { bindings.emplace_back(outputs[p.first].dptr_);