This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conv… #13659

Merged: 1 commit, merged on Dec 19, 2018
8 changes: 4 additions & 4 deletions src/executor/tensorrt_pass.cc
@@ -324,10 +324,10 @@ nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
std::unordered_map<std::string, NDArray>* const params_map) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_trt_op");
op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
p->attrs.dict["serialized_input_map"] = trt_param.serialized_input_map;
p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
p->attrs.dict["serialized_output_map"] = onnx_param.serialized_output_map;
p->attrs.dict["serialized_input_map"] = onnx_param.serialized_input_map;
p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph;
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
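For readers unfamiliar with this pass: the subgraph selected for TensorRT is collapsed into a single `_trt_op` node, and the three serialized blobs are stashed in the node's string attribute dictionary before the op's `attr_parser` rebuilds typed state from them. Below is a minimal, self-contained sketch of that store-then-parse round trip; `FakeNodeAttrs` and `FakeParam` are hypothetical stand-ins, not MXNet types.

```cpp
// Minimal, self-contained sketch (not MXNet code) of the pattern used above:
// the graph pass stores pre-serialized strings in a node's string->string
// attribute dictionary, and the op's attr_parser later rebuilds typed state
// from those strings. FakeNodeAttrs/FakeParam are hypothetical names.
#include <iostream>
#include <map>
#include <string>

struct FakeNodeAttrs {
  std::map<std::string, std::string> dict;  // stand-in for nnvm::NodeAttrs::dict
};

struct FakeParam {
  std::string serialized_onnx_graph;
  std::string serialized_input_map;
  std::string serialized_output_map;
};

// Stand-in for the op's attr_parser: copy the raw strings back out of the dict.
FakeParam ParseAttrs(const FakeNodeAttrs& attrs) {
  FakeParam p;
  p.serialized_onnx_graph = attrs.dict.at("serialized_onnx_graph");
  p.serialized_input_map  = attrs.dict.at("serialized_input_map");
  p.serialized_output_map = attrs.dict.at("serialized_output_map");
  return p;
}

int main() {
  FakeNodeAttrs attrs;
  // Pass side: store already-serialized blobs under fixed keys.
  attrs.dict["serialized_onnx_graph"] = "<onnx bytes>";
  attrs.dict["serialized_input_map"]  = "<input map bytes>";
  attrs.dict["serialized_output_map"] = "<output map bytes>";

  // Op side: the parser reconstructs its parameter struct from the dict.
  FakeParam p = ParseAttrs(attrs);
  std::cout << p.serialized_onnx_graph << "\n";
  return 0;
}
```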
42 changes: 38 additions & 4 deletions src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +37,6 @@
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>

#include <algorithm>
@@ -49,13 +48,48 @@
#include <utility>
#include <string>

#include "./tensorrt-inl.h"
#include "../operator_common.h"
#include "../../common/utils.h"
#include "../../common/serialization.h"

namespace mxnet {
namespace op {

namespace nnvm_to_onnx {
enum class TypeIO { Inputs = 0, Outputs = 1 };
using NameToIdx_t = std::map<std::string, int32_t>;
using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
} // namespace nnvm_to_onnx

struct ONNXParam : public dmlc::Parameter<ONNXParam> {
std::string serialized_onnx_graph;
std::string serialized_input_map;
std::string serialized_output_map;
nnvm_to_onnx::NameToIdx_t input_map;
nnvm_to_onnx::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

ONNXParam() {}

ONNXParam(const ::onnx::ModelProto& onnx_graph,
const nnvm_to_onnx::InferenceMap_t& input_map,
const nnvm_to_onnx::NameToIdx_t& output_map) {
common::Serialize(input_map, &serialized_input_map);
common::Serialize(output_map, &serialized_output_map);
onnx_graph.SerializeToString(&serialized_onnx_graph);
}

DMLC_DECLARE_PARAMETER(ONNXParam) {
DMLC_DECLARE_FIELD(serialized_onnx_graph)
.describe("Serialized ONNX graph");
DMLC_DECLARE_FIELD(serialized_input_map)
.describe("Map from inputs to topological order as input.");
DMLC_DECLARE_FIELD(serialized_output_map)
.describe("Map from outputs to order in g.outputs.");
}
};

namespace nnvm_to_onnx {

using namespace nnvm;
@@ -76,7 +110,7 @@ void ConvertConstant(GraphProto* const graph_proto,
const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer);

void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
GraphProto* const graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
@@ -133,7 +167,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);

TRTParam ConvertNnvmGraphToOnnx(
ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
std::unordered_map<std::string, NDArray> *const shared_buffer);

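The aliases added to `nnvm_to_onnx` keep the same layout the old `tensorrt` namespace used: `InferenceTuple_t` records (position in `g.outputs`, shape, storage type, dtype), and `InferenceMap_t` keys those tuples by node name. A small standalone sketch of that layout, assuming a plain `std::vector<int64_t>` in place of `TShape`:

```cpp
// Self-contained sketch of how the InferenceMap_t/InferenceTuple_t aliases
// introduced above are meant to be used. mxnet::TShape is replaced by a
// plain std::vector<int64_t> so the example compiles on its own; the field
// order (output index, shape, storage type, dtype) mirrors the header.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

using FakeShape = std::vector<int64_t>;  // stand-in for TShape
using InferenceTuple = std::tuple<uint32_t, FakeShape, int, int>;
using InferenceMap = std::map<std::string, InferenceTuple>;

int main() {
  InferenceMap output_map;

  // Record that node "softmax_output" is output 0 of the graph, with shape
  // (32, 10); the last two fields are integer storage-type and dtype codes.
  output_map.emplace("softmax_output",
                     InferenceTuple{0, FakeShape{32, 10}, 0, 0});

  // A consumer (e.g. shape/type inference for the fused op) reads it back.
  const auto& t = output_map.at("softmax_output");
  std::cout << "output index: " << std::get<0>(t)
            << ", rank: " << std::get<1>(t).size()
            << ", dtype code: " << std::get<3>(t) << "\n";
  return 0;
}
```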
34 changes: 18 additions & 16 deletions src/operator/contrib/nnvm_to_onnx.cc
@@ -47,22 +47,24 @@
#include "../../operator/nn/fully_connected-inl.h"
#include "../../operator/nn/pooling-inl.h"
#include "../../operator/softmax_output-inl.h"
#include "./tensorrt-inl.h"

#if MXNET_USE_TENSORRT_ONNX_CHECKER
#include <onnx/checker.h>
#endif // MXNET_USE_TENSORRT_ONNX_CHECKER

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(ONNXParam);

namespace nnvm_to_onnx {

op::TRTParam ConvertNnvmGraphToOnnx(
op::ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
op::TRTParam trt_param;
op::tensorrt::NameToIdx_t trt_input_map;
op::tensorrt::InferenceMap_t trt_output_map;
op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

const nnvm::IndexedGraph& ig = g.indexed_graph();
const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
@@ -105,7 +107,7 @@ op::TRTParam ConvertNnvmGraphToOnnx(
current_input++;
continue;
}
trt_input_map.emplace(node_name, current_input++);
onnx_input_map.emplace(node_name, current_input++);
ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
} else {
// If it's not a placeholder, then by exclusion it's a constant.
@@ -140,23 +142,23 @@ op::TRTParam ConvertNnvmGraphToOnnx(
auto out_iter = output_lookup.find(node_name);
// We found an output
if (out_iter != output_lookup.end()) {
ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
ConvertOutput(&onnx_output_map, graph_proto, out_iter, node_name, g,
storage_types, dtypes);
} // output found
} // conversion function exists
} // loop over i from 0 to num_nodes

model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
&trt_param.serialized_input_map);
common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
&trt_param.serialized_output_map);
model_proto.SerializeToString(&onnx_param.serialized_onnx_graph);
common::Serialize<op::nnvm_to_onnx::NameToIdx_t>(onnx_input_map,
&onnx_param.serialized_input_map);
common::Serialize<op::nnvm_to_onnx::InferenceMap_t>(onnx_output_map,
&onnx_param.serialized_output_map);

#if MXNET_USE_TENSORRT_ONNX_CHECKER
onnx::checker::check_model(model_proto);
#endif // MXNET_USE_TENSORRT_ONNX_CHECKER

return trt_param;
return onnx_param;
}

void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -489,7 +491,7 @@ void ConvertConstant(
}

void ConvertOutput(
op::tensorrt::InferenceMap_t* const trt_output_map,
op::nnvm_to_onnx::InferenceMap_t* const output_map,
GraphProto* const graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name, const nnvm::Graph& g,
@@ -501,10 +503,10 @@ void ConvertOutput(
int dtype = dtypes[out_idx];

// This should work with fp16 as well
op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
op::nnvm_to_onnx::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
dtype};

trt_output_map->emplace(node_name, out_tuple);
output_map->emplace(node_name, out_tuple);

auto graph_out = graph_proto->add_output();
auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
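Context for the loop touched above: `ConvertNnvmGraphToOnnx` walks the indexed graph in topological order, records placeholders in `onnx_input_map`, dispatches each compute node to a per-operator conversion function, and records graph outputs via `ConvertOutput` into `onnx_output_map`. The sketch below only illustrates the dispatch-table shape of that walk, with hypothetical converter lambdas; the real converters emit `onnx::NodeProto` messages rather than text.

```cpp
// Hypothetical, self-contained sketch of the dispatch pattern the converter
// follows: each supported operator name maps to a conversion callback, and
// nodes without an entry are simply skipped by this example.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using Converter = std::function<void(const std::string& node_name)>;

int main() {
  std::unordered_map<std::string, Converter> converters = {
      {"Convolution", [](const std::string& n) { std::cout << n << " -> Conv\n"; }},
      {"Pooling",     [](const std::string& n) { std::cout << n << " -> Pool\n"; }},
      {"FullyConnected",
                      [](const std::string& n) { std::cout << n << " -> Gemm\n"; }},
  };

  // Topologically ordered (op name, node name) pairs standing in for the
  // indexed-graph walk in ConvertNnvmGraphToOnnx.
  std::vector<std::pair<std::string, std::string>> nodes = {
      {"Convolution", "conv0"}, {"Pooling", "pool0"}, {"FullyConnected", "fc0"}};

  for (const auto& node : nodes) {
    auto it = converters.find(node.first);
    if (it != converters.end()) {
      it->second(node.second);  // emit the ONNX node for this operator
    }
  }
  return 0;
}
```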
38 changes: 2 additions & 36 deletions src/operator/contrib/tensorrt-inl.h
@@ -38,7 +38,6 @@
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>

#include <algorithm>
#include <iostream>
@@ -49,6 +48,7 @@
#include <utility>
#include <string>

#include "nnvm_to_onnx-inl.h"
#include "../operator_common.h"
#include "../../common/utils.h"
#include "../../common/serialization.h"
@@ -60,49 +60,15 @@ namespace mxnet {
namespace op {

using namespace nnvm;
using namespace ::onnx;
using int64 = ::google::protobuf::int64;

namespace tensorrt {
enum class TypeIO { Inputs = 0, Outputs = 1 };
using NameToIdx_t = std::map<std::string, int32_t>;
using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
} // namespace tensorrt

using trt_name_to_idx = std::map<std::string, uint32_t>;

struct TRTParam : public dmlc::Parameter<TRTParam> {
std::string serialized_onnx_graph;
std::string serialized_input_map;
std::string serialized_output_map;
tensorrt::NameToIdx_t input_map;
tensorrt::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

TRTParam() {}

TRTParam(const ::onnx::ModelProto& onnx_graph,
const tensorrt::InferenceMap_t& input_map,
const tensorrt::NameToIdx_t& output_map) {
common::Serialize(input_map, &serialized_input_map);
common::Serialize(output_map, &serialized_output_map);
onnx_graph.SerializeToString(&serialized_onnx_graph);
}

DMLC_DECLARE_PARAMETER(TRTParam) {
DMLC_DECLARE_FIELD(serialized_onnx_graph)
.describe("Serialized ONNX graph");
DMLC_DECLARE_FIELD(serialized_input_map)
.describe("Map from inputs to topological order as input.");
DMLC_DECLARE_FIELD(serialized_output_map)
.describe("Map from outputs to order in g.outputs.");
}
};

struct TRTEngineParam {
nvinfer1::IExecutionContext* trt_executor;
std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
std::vector<std::pair<uint32_t, nnvm_to_onnx::TypeIO> > binding_map;
};

} // namespace op
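`TRTEngineParam::binding_map` now uses the relocated `nnvm_to_onnx::TypeIO`, but its meaning is unchanged: one entry per TensorRT binding slot, pairing an index into the op's inputs or outputs with a direction flag. A self-contained illustration of that structure (the enum is copied here so the snippet compiles on its own):

```cpp
// Standalone sketch of the binding_map idea: every TensorRT engine binding
// slot is recorded as (index into the op's inputs or outputs, direction),
// so the runtime can later fill the raw pointer array in binding order.
// TypeIO mirrors the enum declared in nnvm_to_onnx-inl.h.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class TypeIO { Inputs = 0, Outputs = 1 };

int main() {
  // Suppose the engine exposes three bindings: two inputs, then one output.
  std::vector<std::pair<uint32_t, TypeIO>> binding_map = {
      {0, TypeIO::Inputs},   // binding 0 <- op input 0
      {1, TypeIO::Inputs},   // binding 1 <- op input 1
      {0, TypeIO::Outputs},  // binding 2 -> op output 0
  };

  for (size_t b = 0; b < binding_map.size(); ++b) {
    const auto& entry = binding_map[b];
    std::cout << "binding " << b << " uses "
              << (entry.second == TypeIO::Inputs ? "input " : "output ")
              << entry.first << "\n";
  }
  return 0;
}
```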
28 changes: 13 additions & 15 deletions src/operator/contrib/tensorrt.cc
@@ -44,20 +44,18 @@
namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(TRTParam);

OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
tensorrt::NameToIdx_t input_map,
tensorrt::NameToIdx_t output_map) {
nnvm_to_onnx::NameToIdx_t input_map,
nnvm_to_onnx::NameToIdx_t output_map) {
TRTEngineParam param;
for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
const std::string& binding_name = trt_engine->getBindingName(b);
if (trt_engine->bindingIsInput(b)) {
param.binding_map.emplace_back(input_map[binding_name],
tensorrt::TypeIO::Inputs);
nnvm_to_onnx::TypeIO::Inputs);
} else {
param.binding_map.emplace_back(output_map[binding_name],
tensorrt::TypeIO::Outputs);
nnvm_to_onnx::TypeIO::Outputs);
}
}
param.trt_executor = trt_engine->createExecutionContext();
@@ -67,7 +65,7 @@ OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
const std::vector<TShape>& /*ishape*/,
const std::vector<int>& /*itype*/) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);

::onnx::ModelProto model_proto;
bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
@@ -82,15 +80,15 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
node_param.serialized_onnx_graph, batch_size, 1 << 30);

tensorrt::NameToIdx_t output_map;
nnvm_to_onnx::NameToIdx_t output_map;
for (auto& el : node_param.output_map) {
output_map[el.first] = std::get<0>(el.second);
}
return GetPtrMapping(trt_engine, node_param.input_map, output_map);
}

void TRTParamParser(nnvm::NodeAttrs* attrs) {
TRTParam param_;
ONNXParam param_;

try {
param_.Init(attrs->dict);
@@ -114,7 +112,7 @@

inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
std::vector<TShape>* out_shape) {
const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto &node_param = nnvm::get<ONNXParam>(attrs.parsed);
for (auto& el : node_param.output_map) {
(*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
}
@@ -131,7 +129,7 @@ inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask

inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
std::vector<int>* out_dtype) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
for (auto& el : node_param.output_map) {
(*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
}
@@ -140,7 +138,7 @@ inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,

inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
std::vector<std::string> output;
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
output.resize(node_param.input_map.size());
for (auto& el : node_param.input_map) {
output[el.second] = el.first;
@@ -150,7 +148,7 @@ inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {

inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
std::vector<std::string> output;
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
output.resize(node_param.output_map.size());
for (auto& el : node_param.output_map) {
output[std::get<0>(el.second)] = el.first;
@@ -162,11 +160,11 @@ NNVM_REGISTER_OP(_trt_op)
.describe(R"code(TRT operation (one engine)
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
return node_param.input_map.size();
})
.set_num_outputs([](const NodeAttrs& attrs) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
return node_param.output_map.size();
})
.set_attr_parser(TRTParamParser)
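One detail worth calling out in this file: `TRTListInputNames`/`TRTListOutputNames` invert the stored maps, turning name -> position entries back into a position-ordered name list. A minimal sketch of that inversion, with simplified types standing in for the real param struct:

```cpp
// Self-contained sketch of the map inversion used by the name-listing
// attributes: ONNXParam keeps name -> topological position, and the op
// needs the position-ordered list of names back.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // name -> topological input position, as kept in ONNXParam::input_map.
  std::map<std::string, int32_t> input_map = {
      {"data", 0}, {"softmax_label", 1}};

  std::vector<std::string> names(input_map.size());
  for (const auto& el : input_map) {
    names[el.second] = el.first;  // place each name at its recorded position
  }

  for (const auto& n : names) std::cout << n << "\n";  // data, softmax_label
  return 0;
}
```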
2 changes: 1 addition & 1 deletion src/operator/contrib/tensorrt.cu
@@ -52,7 +52,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
std::vector<void*> bindings;
bindings.reserve(param.binding_map.size());
for (auto& p : param.binding_map) {
if (p.second == tensorrt::TypeIO::Inputs) {
if (p.second == nnvm_to_onnx::TypeIO::Inputs) {
bindings.emplace_back(inputs[p.first].dptr_);
} else {
bindings.emplace_back(outputs[p.first].dptr_);
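The change here is only the namespace move; the loop itself still gathers raw device pointers in engine-binding order. A hypothetical standalone version of that gathering step, with plain float buffers in place of `NDArray` data pointers:

```cpp
// Sketch of the pointer-gathering loop above, simplified so it stands alone:
// walk binding_map in engine-binding order and pick the matching input or
// output buffer for each slot before handing the array to the executor.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class TypeIO { Inputs = 0, Outputs = 1 };

int main() {
  std::vector<std::vector<float>> inputs  = {{1.f, 2.f}, {3.f, 4.f}};
  std::vector<std::vector<float>> outputs = {{0.f, 0.f}};

  std::vector<std::pair<uint32_t, TypeIO>> binding_map = {
      {0, TypeIO::Inputs}, {1, TypeIO::Inputs}, {0, TypeIO::Outputs}};

  std::vector<void*> bindings;
  bindings.reserve(binding_map.size());
  for (const auto& p : binding_map) {
    if (p.second == TypeIO::Inputs) {
      bindings.push_back(inputs[p.first].data());   // input tensor buffer
    } else {
      bindings.push_back(outputs[p.first].data());  // output tensor buffer
    }
  }

  std::cout << "prepared " << bindings.size() << " bindings\n";
  return 0;
}
```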