Fix for optimize_for multiple subgraph properties issue (apache#19263)
* initial commit

* fixed whitespace

Co-authored-by: Ubuntu <ubuntu@ip-172-31-6-220.us-west-2.compute.internal>
2 people authored and bgawrych committed Apr 8, 2021
1 parent 5b7826e commit 346528e
211 changes: 110 additions & 101 deletions src/c_api/c_api_symbolic.cc
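Before this change, MXOptimizeForBackend built a single nnvm::Graph g up front and ran every subgraph property registered by the backend against that one graph, copying the result back into the output symbol only once, after the loop. The fix moves the graph setup (input collection, optional shape/dtype/stype inference, the in_args/in_aux attributes, and the dedup_subgraph flag) into an init_graph lambda, so each property partitions a graph rebuilt from the symbol that the previous property produced, and s->outputs is updated inside the loop. The sketch below shows the pattern in isolation; it is not MXNet code: Symbol and Graph are stand-ins for nnvm::Symbol and nnvm::Graph, and the string append stands in for the BuildSubgraph pass.

#include <iostream>
#include <string>
#include <vector>

// Stand-ins for nnvm::Symbol and nnvm::Graph.
struct Symbol { std::vector<std::string> outputs; };
struct Graph  { std::vector<std::string> outputs; };

int main() {
  Symbol s{{"conv", "relu"}};

  // Counterpart of the init_graph lambda in this commit: every call starts
  // from the symbol's current outputs, never from a graph that an earlier
  // partitioning pass already mutated.
  auto init_graph = [](Symbol* sym) { return Graph{sym->outputs}; };

  const std::vector<std::string> properties = {"propA", "propB"};
  for (const auto& property : properties) {
    Graph g = init_graph(&s);       // fresh graph per property (the fix)
    g.outputs.push_back(property);  // stand-in for the BuildSubgraph pass
    s.outputs = g.outputs;          // feed the result to the next property
  }

  for (const auto& out : s.outputs) std::cout << out << '\n';
  // prints conv, relu, propA, propB: each property saw the previous result
}

The same lambda is reused on the graph-pass branch, so both dispatch paths start from an identically initialized, freshly inferred graph.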
@@ -1372,137 +1372,145 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
   *s = sym->Copy();
-  nnvm::Graph g = Symbol2Graph(*s);
-  const auto& indexed_graph = g.indexed_graph();
-  const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
-  std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
-  size_t num_forward_inputs = input_names.size();
-  if (args_len || aux_len) {
-    NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
-    NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    if (!skip_infer) {
-      Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-      mxnet::ShapeVector arg_shapes(args_len + aux_len);
-      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-      StorageTypeVector arg_stypes(args_len + aux_len);
+
+  // create a data structure from pointer array
+  std::unordered_map<std::string, std::string> options_map;
+  for (mx_uint i = 0; i < num_options; ++i)
+    options_map.emplace(keys[i], vals[i]);
+
+  NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
+  NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
+  NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
+  NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
+
+  auto init_graph = [&](auto s) {
+    nnvm::Graph g = Symbol2Graph(*s);
+    const auto& indexed_graph = g.indexed_graph();
+    const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
+    std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
+    size_t num_forward_inputs = input_names.size();
+
+    if (args_len || aux_len) {
+      if (!skip_infer) {
+        Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+        mxnet::ShapeVector arg_shapes(args_len + aux_len);
+        nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+        StorageTypeVector arg_stypes(args_len + aux_len);
 
-      // create the input shape, dtype and stype maps
-      std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
-      for (uint32_t i = 0; i < num_input_shapes; ++i) {
-        input_shape_map.emplace(input_shape_names[i],
-                                mxnet::TShape(input_shape_data + input_shape_idx[i],
-                                              input_shape_data + input_shape_idx[i+1]));
-      }
-      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
-      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
-        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
-      }
-      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
-      for (uint32_t i = 0; i < num_input_stypes; ++i) {
-        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
-      }
+        // create the input shape, dtype and stype maps
+        std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
+        for (uint32_t i = 0; i < num_input_shapes; ++i) {
+          input_shape_map.emplace(input_shape_names[i],
+                                  mxnet::TShape(input_shape_data + input_shape_idx[i],
+                                                input_shape_data + input_shape_idx[i+1]));
+        }
+        std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+        for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+          input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+        }
+        std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+        for (uint32_t i = 0; i < num_input_stypes; ++i) {
+          input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
+        }
 
-      size_t args_top = 0, aux_top = 0;
-      // loop over inputs to symbol in order and add to args/aux if mutable
-      for (size_t i = 0; i < num_forward_inputs; ++i) {
-        const uint32_t nid = indexed_graph.input_nodes().at(i);
-        if (mutable_nodes.count(nid)) {
-          CHECK_LT(aux_top, aux_len)
-            << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
-          if (in_aux_ptr[aux_top] != nullptr) {
-            const auto &in_arg = *(in_aux_ptr[aux_top]);
-            arg_shapes[i] = in_arg.shape();
-            arg_dtypes[i] = in_arg.dtype();
-            arg_stypes[i] = in_arg.storage_type();
-          }
-          aux_top++;
-        } else {
-          auto name = input_names[i];
-          CHECK_LT(args_top, args_len)
-            << "Cannot find arg '" << name << "' in provided args to optimize_for";
-          if (in_args_ptr[args_top] != nullptr) {
-            const auto &in_arg = *(in_args_ptr[args_top]);
-            arg_shapes[i] = in_arg.shape();
-            arg_dtypes[i] = in_arg.dtype();
-            arg_stypes[i] = in_arg.storage_type();
-          } else {
-            // input_names[i] is not in args but can be in the optional
-            // shape/type/stype attribute dicts.
-            auto it_shape = input_shape_map.find(name);
-            if (it_shape != input_shape_map.end()) {
-              arg_shapes[i] = it_shape->second;
-            }
-            auto it_type = input_dtype_map.find(name);
-            if (it_type != input_dtype_map.end()) {
-              arg_dtypes[i] = it_type->second;
-            }
-            it_type = input_stype_map.find(name);
-            if (it_type != input_stype_map.end()) {
-              arg_stypes[i] = it_type->second;
-            }
-          }
-          args_top++;
-        }
-      }
+        size_t args_top = 0, aux_top = 0;
+        // loop over inputs to symbol in order and add to args/aux if mutable
+        for (size_t i = 0; i < num_forward_inputs; ++i) {
+          const uint32_t nid = indexed_graph.input_nodes().at(i);
+          if (mutable_nodes.count(nid)) {
+            CHECK_LT(aux_top, aux_len)
+              << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+            if (in_aux_ptr[aux_top] != nullptr) {
+              const auto &in_arg = *(in_aux_ptr[aux_top]);
+              arg_shapes[i] = in_arg.shape();
+              arg_dtypes[i] = in_arg.dtype();
+              arg_stypes[i] = in_arg.storage_type();
+            }
+            aux_top++;
+          } else {
+            auto name = input_names[i];
+            CHECK_LT(args_top, args_len)
+              << "Cannot find arg '" << name << "' in provided args to optimize_for";
+            if (in_args_ptr[args_top] != nullptr) {
+              const auto &in_arg = *(in_args_ptr[args_top]);
+              arg_shapes[i] = in_arg.shape();
+              arg_dtypes[i] = in_arg.dtype();
+              arg_stypes[i] = in_arg.storage_type();
+            } else {
+              // input_names[i] is not in args but can be in the optional
+              // shape/type/stype attribute dicts.
+              auto it_shape = input_shape_map.find(name);
+              if (it_shape != input_shape_map.end()) {
+                arg_shapes[i] = it_shape->second;
+              }
+              auto it_type = input_dtype_map.find(name);
+              if (it_type != input_dtype_map.end()) {
+                arg_dtypes[i] = it_type->second;
+              }
+              it_type = input_stype_map.find(name);
+              if (it_type != input_stype_map.end()) {
+                arg_stypes[i] = it_type->second;
+              }
+            }
+            args_top++;
+          }
+        }
 
-      g.attrs["context"] = std::make_shared<nnvm::any>(
-          exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+        g.attrs["context"] = std::make_shared<nnvm::any>(
+            exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
 
-      // infer shapes
-      g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-      // infer dtypes
-      g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-      // infer stypes
-      g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
-    }
-    // set args/aux as attributes on graph so that subgraph property can use them
-    std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
-    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
-    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+        // infer shapes
+        g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+        // infer dtypes
+        g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+        // infer stypes
+        g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+      }
+      // set args/aux as attributes on graph so that subgraph property can use them
+      std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+      g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+      g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
 
-    std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
-    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
-    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
-  } else {
-    // args/aux were not specified, so set nullptr/empty-lists
-    NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
-    std::vector<std::string> arg_names;
-    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
-    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+      std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+      g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+      g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
+    } else {
+      // args/aux were not specified, so set nullptr/empty-lists
+      NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
+      std::vector<std::string> arg_names;
+      g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+      g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
 
-    NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
-    std::vector<std::string> aux_names;
-    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
-    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
-  }
-  // create a data structure from pointer array
-  std::unordered_map<std::string, std::string> options_map;
-  for (mx_uint i = 0; i < num_options; ++i)
-    options_map.emplace(keys[i], vals[i]);
+      NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
+      std::vector<std::string> aux_names;
+      g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+      g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
+    }
 
-  // set dedup option as attribute on graph to enable dedup during partitioning
-  if (options_map.count("dedup_subgraph") > 0 &&
-      options_map.at("dedup_subgraph").compare("True") == 0)
-    g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+    // set dedup option as attribute on graph to enable dedup during partitioning
+    if (options_map.count("dedup_subgraph") > 0 &&
+        options_map.at("dedup_subgraph").compare("True") == 0)
+      g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+    return g;
+  };
 
   if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
     // use subgraph backend
     const auto backend = mxnet::op::SubgraphBackendRegistry
                          ::Get()->GetSubgraphBackend(backend_name);
     const auto& subgraph_prop_list = backend->GetSubgraphProperties();
     for (auto property : subgraph_prop_list) {
+      nnvm::Graph g = init_graph(s);
       property->PrePartition(g, options_map);
       g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
       g = ApplyPass(std::move(g), "BuildSubgraph");
       g.attrs.erase("subgraph_property");
       property->PostPartition(g);
+      s->outputs = g.outputs;
     }
   } else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
     // use graph pass
+    nnvm::Graph g = init_graph(s);
     g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
     g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
     g = ApplyPass(std::move(g), backend_name);
@@ -1515,6 +1523,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
     g.attrs.erase("new_aux");
     g.attrs.erase("new_arg_names");
     g.attrs.erase("new_aux_names");
+    s->outputs = g.outputs;
 
     NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
     NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
@@ -1546,7 +1555,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
     // cannot find graph pass or subgraph backend registered in this name
     LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
   }
-  s->outputs = g.outputs;
+
   *ret_sym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
