Skip to content

Commit

Permalink
Enable shape hints during infer_shape pass (#107)
Browse files Browse the repository at this point in the history
* enable shape hints during infer_shape pass

* fix comment
  • Loading branch information
Ziheng Jiang authored and piiswrong committed Mar 10, 2017
1 parent 85aaf57 commit 0d64855
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ class IndexedGraph {
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}

/*! \return whether a node is existed in the indexed graph */
inline bool exist(const nnvm::Node* node) const {
return node2index_.count(node);
}

// disalllow copy assign
IndexedGraph(const IndexedGraph&) = delete;

Expand Down
28 changes: 28 additions & 0 deletions include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ struct NodeEntry {
uint32_t version;
};

/*!
* \brief This lets you use a NodeEntry as a key in a unordered_map of the form
* unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
*/
struct NodeEntryHash {
size_t operator()(const NodeEntry& e) const {
return std::hash<Node*>()(e.node.get()) ^
(std::hash<size_t>()(e.index) << 1 >> 1) ^
(std::hash<size_t>()(e.version) << 1);
}
};

/*!
* \brief This lets you use a NodeEntry as a key in a unordered_map of the form
* unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
*/
struct NodeEntryEqual {
size_t operator()(const NodeEntry& a, const NodeEntry& b) const {
return (a.node.get() == b.node.get()) &&
(a.index == b.index) &&
(a.version == b.version);
}
};

/*! use NodeEntry as key in unordered_map */
template<typename ValueType>
using NodeEntryMap = std::unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>;

/*!
* \brief The attributes of the current operation node.
* Usually are additional parameters like axis,
Expand Down
15 changes: 14 additions & 1 deletion src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ Graph InferAttr(Graph &&ret,
ret.attrs.erase(input_name);
}

// get the shape hints
std::string shape_hints_key = std::string(attr_name) + "_hints";
if (ret.attrs.count(shape_hints_key)) {
NodeEntryMap<AttrType> shape_hints =
ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
for (const auto& kv : shape_hints) {
NodeEntry e = kv.first;
if (idx.exist(e.node.get())) {
rshape[idx.entry_id(kv.first)] = kv.second;
}
}
}

std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
Expand All @@ -75,7 +88,7 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (is_backward.get(inode.source->op(), false)) {
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
Expand Down

0 comments on commit 0d64855

Please sign in to comment.