Skip to content

Commit

Permalink
Enhance subgraph API (apache#14113)
Browse files Browse the repository at this point in the history
* Enhance subgraph API

* Fix lint

* Trigger CI

* Fix test

* split into another PR

* Rename partition_graph to build_graph

* Fix lint

* Fix merge

* run CI

* run CI

* fix quantize script

* fix ssd script

* Address reminisce comment
  • Loading branch information
ZhennanQin authored and vdantu committed Mar 31, 2019
1 parent 3881ec0 commit e1efc8e
Show file tree
Hide file tree
Showing 11 changed files with 709 additions and 294 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ extern "C" {
* to the input graph for partitioning. This function should be
* used only for the testing purpose.
*/
MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
MXNET_DLL int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
g = ApplyPass(std::move(g), "BuildSubgraph");
s->outputs = g.outputs;
}
*ret_sym_handle = s;
Expand Down
4 changes: 2 additions & 2 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "./c_api_common.h"
#include "../operator/subgraph/subgraph_property.h"

int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
Expand All @@ -49,7 +49,7 @@ int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
property->SetAttr("graph", g);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
g = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
s->outputs = g.outputs;
}
}
Expand Down
53 changes: 26 additions & 27 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1442,30 +1442,29 @@ static nnvm::Graph InferForwardAttrs(nnvm::Graph g,

// Given input attr arrays, partition the graph using the backend name equal to prop_name.
// This is a common function for bind and simple_bind flows.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
mxnet::op::SubgraphPropertyPtr subgraph_prop,
const mxnet::ShapeVector& arg_shapes,
const nnvm::DTypeVector& arg_dtypes,
const StorageTypeVector& arg_stypes,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src,
mxnet::op::SubgraphPropertyPtr subgraph_prop,
const mxnet::ShapeVector& arg_shapes,
const nnvm::DTypeVector& arg_dtypes,
const StorageTypeVector& arg_stypes, const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
nnvm::Symbol ret = src.Copy();
nnvm::Graph g;
g.outputs = ret.outputs;
g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
aux_state_ctxes);
subgraph_prop->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
g = ApplyPass(std::move(g), "PartitionGraph");
g = ApplyPass(std::move(g), "BuildSubgraph");
ret.outputs = g.outputs;
return ret;
}

// Given input attr dicts, partition the graph using the backend name equal to prop_name.
// This is for simple_bind flow.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src,
const std::string& prop_name,
const std::unordered_map<std::string, mxnet::TShape>
& arg_shape_map,
Expand Down Expand Up @@ -1547,7 +1546,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
arg_stypes[i] = it3->second;
}
}
ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, *in_arg_ctxes, *aux_state_ctxes);
// Reorder in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types according to
// partitioned symbol input sequence
Expand All @@ -1573,13 +1572,13 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,

// Given input ndarrays, partition the graph using the backend name equal to prop_name.
// This is for bind flow.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grad_store,
std::vector<OpReqType>* grad_req_type,
std::vector<NDArray>* aux_states) {
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, const std::string& prop_name,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grad_store,
std::vector<OpReqType>* grad_req_type,
std::vector<NDArray>* aux_states) {
// setup map for in_args, arg_grad_store, grad_req_type and aux_states
std::unordered_map<std::string, NDArray> in_args_map;
std::unordered_map<std::string, NDArray> arg_grad_store_map;
Expand Down Expand Up @@ -1664,8 +1663,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& p
}
}

ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, in_arg_ctxes, aux_state_ctxes);
ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, in_arg_ctxes, aux_state_ctxes);
}
// Reorder in_args, arg_grad_store, grad_req_type and aux_states according to partitioned symbol
// input sequence
Expand Down Expand Up @@ -1713,9 +1712,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::vector<Context> tmp_aux_state_ctxes = aux_state_ctxes;
std::vector<OpReqType> tmp_grad_req_types = grad_req_types;
if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes,
&tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes);
symbol = exec::BuildSubgraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes,
&tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes);
}
exec->Init(symbol, default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, tmp_grad_req_types,
Expand All @@ -1738,9 +1737,9 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
std::vector<NDArray> tmp_aux_states = aux_states;

if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), default_ctx, group2ctx,
&tmp_in_args, &tmp_arg_grad_store, &tmp_grad_req_type,
&tmp_aux_states);
symbol =
exec::BuildSubgraph(symbol, exec->subgraph_property(), default_ctx, group2ctx, &tmp_in_args,
&tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states);
}
exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store, tmp_grad_req_type,
tmp_aux_states, reinterpret_cast<Executor*>(shared_exec));
Expand Down
Loading

0 comments on commit e1efc8e

Please sign in to comment.