diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index c14482affbde..e4dd3f6677e4 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -322,6 +322,272 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   return ret;
 }
 
+template<typename IsNone, typename FDefault>
+nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
+                           const nnvm::TShape empty_val,
+                           const char* infer_name,
+                           const char* input_name,
+                           const char* attr_key_name,
+                           const char* attr_name,
+                           const char* unknown_name,
+                           IsNone fis_none,
+                           FDefault fdefault,
+                           bool bwd_identity_assign,
+                           const char* dispatch_mode_name,
+                           const DispatchMode default_mode_val = DispatchMode::kUndefined) {
+  using nnvm::IndexedGraph;
+  using nnvm::Op;
+  using AttrType = nnvm::TShape;
+  using FInferType = nnvm::FInferShape;
+  using AttrVector = std::vector<AttrType>;
+  using NodeAttrVector = std::vector<DispatchMode>;
+  using dmlc::any;
+  const IndexedGraph& idx = ret.indexed_graph();
+  static auto& finfer_shape =
+      Op::GetAttr<FInferType>(infer_name);
+  static auto& is_backward =
+      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
+  // gradient function, used to get node correspondence.
+  static auto& fgrad =
+      Op::GetAttr<nnvm::FGradient>("FGradient");
+  // reshape shape vector
+  AttrVector rshape;
+  // dispatch mode vector
+  DispatchModeVector dispatch_modes;
+  if (ret.attrs.count(attr_name) != 0) {
+    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
+  } else {
+    rshape.resize(idx.num_node_entries(), empty_val);
+  }
+
+  if (ret.attrs.count(input_name) != 0) {
+    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
+    CHECK_LE(shape_args.size(), idx.input_nodes().size())
+        << "More provided " << attr_name << "s than number of arguments.";
+    for (size_t i = 0; i < shape_args.size(); ++i) {
+      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
+    }
+  }
+
+  // get the shape hints
+  std::string shape_hints_key = std::string(attr_name) + "_hints";
+  if (ret.attrs.count(shape_hints_key)) {
+    nnvm::NodeEntryMap<AttrType> shape_hints =
+        ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
+    for (const auto& kv : shape_hints) {
+      nnvm::NodeEntry e = kv.first;
+      if (idx.exist(e.node.get())) {
+        rshape[idx.entry_id(kv.first)] = kv.second;
+      }
+    }
+  }
+
+  std::string shape_attr_key;
+  if (ret.attrs.count(attr_key_name) != 0) {
+    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
+    // erase the provided arguments
+    ret.attrs.erase(attr_key_name);
+  }
+
+  // limit inference to part of the graph
+  uint32_t node_start = 0, node_end = idx.num_nodes();
+  if (ret.attrs.count("node_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+    node_start = range.first;
+    node_end = range.second;
+    CHECK_GE(node_start, 0);
+    CHECK_LE(node_end, idx.num_nodes());
+    ret.attrs.erase("node_range");
+  }
+  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
+  if (ret.attrs.count("entry_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
+    entry_start = range.first;
+    entry_end = range.second;
+    CHECK_GE(entry_start, 0);
+    CHECK_LE(entry_end, idx.num_node_entries());
+    ret.attrs.erase("entry_range");
+  }
+  // populate the node attribute vector
+  if (dispatch_mode_name != nullptr) {
+    if (ret.attrs.count(dispatch_mode_name) != 0) {
+      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
+    } else {
+      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
+    }
+  }
+
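+  // At this point rshape has been seeded from previously stored results,
+  // explicitly provided input shapes, and per-entry shape hints; per-variable
+  // attribute strings (shape_attr_key) are parsed lazily inside infer_step.
+  // The fixed-point loop at the bottom of this function then sweeps
+  // infer_step over [node_start, node_end) until no further entry resolves.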
+  // Temp space for shape inference.
+  std::vector<AttrType> ishape, oshape;
+  // whether a shape is dynamic
+  std::vector<bool> is_dynamic(rshape.size(), 0);
+  // inference step function for nid
+  auto infer_step = [&](uint32_t nid, bool last_iter) {
+    const auto& inode = idx[nid];
+    const std::string name = inode.source->attrs.name;
+    const uint32_t num_inputs = inode.inputs.size();
+    const uint32_t num_outputs = inode.source->num_outputs();
+    if (inode.source->is_variable()) {
+      // Variable node. No operator. Only one output entry.
+      CHECK(inode.source->op() == nullptr);
+      CHECK_EQ(num_outputs, 1U);
+      const uint32_t out_ent_id = idx.entry_id(nid, 0);
+      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
+        auto it = inode.source->attrs.dict.find(shape_attr_key);
+        if (it != inode.source->attrs.dict.end()) {
+          std::istringstream is(it->second);
+          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
+        }
+      }
+      // assign a default value to node attribute
+      if (dispatch_mode_name != nullptr) {
+        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
+      }
+    } else if (is_backward.get(inode.source->op(), false) &&
+               inode.control_deps.size() && bwd_identity_assign) {
+      CHECK(dispatch_mode_name == nullptr)
+          << "Backward inference for node attributes is not available";
+      CHECK_GE(inode.control_deps.size(), 1U)
+          << "BackwardOp need to have control_deps to its forward op";
+      const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
+      nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];
+      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
+      // use gradient function to find out the correspondence.
+      std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
+      for (size_t i = 0; i < ograd.size(); ++i) {
+        ograd[i].index = static_cast<uint32_t>(i);
+      }
+      // input gradient list
+      auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
+      const nnvm::Node* igrad_node = nullptr;
+      // Input gradient assignment
+      for (size_t i = 0; i < igrad.size(); ++i) {
+        if (igrad[i].node->op() == inode.source->op()) {
+          uint32_t eid = idx.entry_id(nid, igrad[i].index);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+          } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+            // Need to skip empty forward shape, because it may not be
+            // available now and it is possible to infer the forward
+            // shape in one of the next few passes
+            CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+                << "Backward shape inconsistent with the forward shape";
+          }
+          if (igrad_node == nullptr) {
+            igrad_node = igrad[i].node.get();
+          } else {
+            CHECK(igrad_node == igrad[i].node.get());
+          }
+        }
+      }
+      // out grad entries
+      CHECK(igrad_node != nullptr)
+          << "Cannot find matching backward op for " << inode.source->attrs.name;
+      for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+        const nnvm::NodeEntry& e = igrad_node->inputs[i];
+        if (e.node == nullptr) {
+          uint32_t eid = idx.entry_id(inode.inputs[i]);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
+          }
+        }
+      }
+    } else {
+      DispatchMode* dispatch_mode = nullptr;
+      bool forward_known = true;
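+      // Entries that stay empty because the producing op has no shape
+      // function, or because one of its inputs is itself dynamic and still
+      // unknown, are flagged in is_dynamic instead of being counted as
+      // plain inference failures; the finfer check below consults this flag.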
+      // Forward operator inference.
+      ishape.resize(num_inputs, empty_val);
+      bool is_input_dynamic_shape = false;
+      for (uint32_t i = 0; i < ishape.size(); ++i) {
+        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
+        if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
+          is_input_dynamic_shape = true;
+        }
+        if (fis_none(ishape[i])) forward_known = false;
+      }
+      oshape.resize(num_outputs, empty_val);
+      for (uint32_t i = 0; i < oshape.size(); ++i) {
+        oshape[i] = rshape[idx.entry_id(nid, i)];
+        if (fis_none(oshape[i])) forward_known = false;
+      }
+      if (dispatch_mode_name != nullptr) {
+        dispatch_mode = &dispatch_modes[nid];
+        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
+      }
+      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
+      if (finfer == nullptr || is_input_dynamic_shape) {
+        for (uint32_t i = 0; i < oshape.size(); ++i) {
+          if (oshape[i].ndim() == 0) {
+            is_dynamic[idx.entry_id(nid, i)] = 1;
+          }
+        }
+      } else if (!forward_known) {
+        if (finfer != nullptr) {
+          // Call inference function of the operator.
+          try {
+            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                                             nid, &ishape, &oshape, dispatch_mode);
+          } catch (const std::exception& e) {
+            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+          }
+        } else {
+          CHECK(!last_iter)
+              << "Attribute " << infer_name
+              << " is not registered by op " << inode.source->op()->name
+              << "; we are not able to complete the inference because of this";
+        }
+      }
+      // Save to the result map.
+      for (uint32_t i = 0; i < num_inputs; ++i) {
+        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
+      }
+      for (uint32_t i = 0; i < num_outputs; ++i) {
+        rshape[idx.entry_id(nid, i)] = oshape[i];
+      }
+    }
+  };
+
+  size_t last_num_unknown;
+  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
+  size_t num_unknown_entry_attr = entry_end - entry_start;
+  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+  int i = 0;
+  do {
+    if (i % 2 == 0) {
+      for (uint32_t nid = node_start; nid < node_end; ++nid) {
+        infer_step(nid, false);
+      }
+    } else {
+      // backward inference
+      for (uint32_t nid = node_end; nid != node_start; --nid) {
+        infer_step(nid - 1, false);
+      }
+    }
+    last_num_unknown = num_unknown;
+    num_unknown = 0;
+    for (size_t j = entry_start; j < entry_end; ++j) {
+      if (fis_none(rshape[j])) {
+        ++num_unknown;
+      }
+    }
+    if (dispatch_mode_name) {
+      for (size_t n = node_start; n < node_end; n++) {
+        if (dispatch_modes[n] == DispatchMode::kUndefined) ++num_unknown;
+      }
+    }
+    ++i;
+  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  // set the shapes
+  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
+  // set the dispatch modes
+  if (dispatch_mode_name) {
+    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
+  }
+  // number of entries that still have unknown shapes.
+  ret.attrs[unknown_name] = std::make_shared<size_t>(num_unknown);
+  return ret;
+}
+
 nnvm::Graph InferShape(nnvm::Graph&& graph,
                        nnvm::ShapeVector&& shape_inputs,
                        const std::string& shape_attr_key) {
@@ -332,7 +598,7 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
   if (shape_attr_key.length() != 0) {
     graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
   }
-  return InferAttr(
+  return InferShapeAttr(
       std::move(graph), nnvm::TShape(),
       "FInferShape", "shape_inputs", "shape_attr_key",
       "shape", "shape_num_unknown_nodes",
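
For reference, a minimal sketch of how this pass is driven from executor code. The symbol `sym`, the `(32, 100)` input shape, and the single-input graph are illustrative assumptions; `InferShape` is declared in the `mxnet::exec` namespace (src/executor/exec_pass.h), and the attribute keys match the string literals passed above.

    // Hedged sketch: run shape inference over a freshly built graph.
    // `sym` (an nnvm::Symbol) and the input shape are assumed for illustration.
    nnvm::Graph g;
    g.outputs = sym.outputs;
    // One shape per input node, in idx.input_nodes() order; entries may be
    // left as empty TShape() and picked up from "__shape__" node attributes.
    nnvm::ShapeVector in_shapes = {nnvm::TShape({32, 100})};
    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
    // Results: one TShape per node entry, plus the count of unresolved entries
    // (nonzero when inference could not complete, e.g. for dynamic shapes).
    const auto& shapes = g.GetAttr<nnvm::ShapeVector>("shape");
    size_t num_unknown = g.GetAttr<size_t>("shape_num_unknown_nodes");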