Fix topological sort for sync nodes execution order #7431
Thanks! This looks good. I have a few ideas for how we can further improve this in the future.
A quick note about the PR description. I'm not sure `iff` in "OrderParent < OrderChild iff parent -> child" is correct. I think `if` is sufficient. Let me rewrite this with generic node names so I can better explain:
Order(a) < Order(b) iff a -> b
Since this is using `iff`, i.e. "if and only if", it means
(Order(a) < Order(b) if a -> b) and (Order(a) < Order(b) only if a -> b)
(Order(a) < Order(b) if a -> b) and ((not Order(a) < Order(b)) if (not a -> b))
(Order(a) < Order(b) if a -> b) and (Order(a) >= Order(b) if (not a -> b))
This means if b is not a child of a, then a must come after b in the order, but this creates a contradiction if a and b have no parent-child relationship. Each must come after the other (note that two nodes cannot have the same order).
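The contradiction can be checked mechanically. The sketch below is only an illustration with hypothetical helper names (`satisfiesIf`, `satisfiesIff`), not code from the PR: for two unrelated nodes, both possible orders satisfy the `if` reading, while neither satisfies the `iff` reading.

```typescript
// Hypothetical edge type for illustration: [parent, child].
type Edge = [string, string];

function hasEdge(edges: Edge[], parent: string, child: string): boolean {
  return edges.some(([p, c]) => p === parent && c === child);
}

// "if" reading: every edge a -> b implies Order(a) < Order(b).
function satisfiesIf(order: string[], edges: Edge[]): boolean {
  return edges.every(([p, c]) => order.indexOf(p) < order.indexOf(c));
}

// "iff" reading: additionally, Order(a) < Order(b) implies a -> b.
function satisfiesIff(order: string[], edges: Edge[]): boolean {
  if (!satisfiesIf(order, edges)) {
    return false;
  }
  for (let i = 0; i < order.length; i++) {
    for (let j = i + 1; j < order.length; j++) {
      // order[i] comes before order[j], so the edge must exist.
      if (!hasEdge(edges, order[i], order[j])) {
        return false;
      }
    }
  }
  return true;
}

// Two nodes with no parent-child relationship.
const noEdges: Edge[] = [];
const bothOrders: string[][] = [['a', 'b'], ['b', 'a']];
const ifResults = bothOrders.map(order => satisfiesIf(order, noEdges));
const iffResults = bothOrders.map(order => satisfiesIff(order, noEdges));
// ifResults is [true, true]; iffResults is [false, false].
```

With `iff`, no valid order exists for a graph containing two unrelated nodes, which is why `if` is the condition a topological sort can actually guarantee.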
    usedNodes.has(typeof node === 'string' ? node : node.name);

function unique(nodes: Node[]): Node[] {
  return [...new Map(nodes.map((node) => [node.name, node])).values()];
For a given node name, is the `Node` object corresponding to it always the same reference?
I think so, but I won't rely on that because 1) there is no contract that node objects with the same name in different maps/arrays are always equal by reference, and 2) the project-wide convention in tfjs-converter is to look up nodes by name string instead of object reference, and I don't want to break that convention.
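To make the difference concrete, here is a small sketch with a hypothetical node shape (`NamedNode`) and helper (`uniqueByName`); it is not the PR's `unique` function, only the same idea. Deduplicating by name collapses two distinct objects that share a name, whereas deduplicating by reference keeps both.

```typescript
// Hypothetical minimal node shape, for illustration only.
interface NamedNode {
  name: string;
  tag: string;
}

// Deduplicate by name; a later node with the same name replaces the earlier
// one, while the position of the first occurrence is kept.
function uniqueByName<T extends {name: string}>(nodes: T[]): T[] {
  const byName = new Map<string, T>();
  for (const node of nodes) {
    byName.set(node.name, node);
  }
  return Array.from(byName.values());
}

const first: NamedNode = {name: 'x', tag: 'first'};
const second: NamedNode = {name: 'x', tag: 'second'};  // same name, new object
const other: NamedNode = {name: 'y', tag: 'only'};

const dedupedByName = uniqueByName([first, other, second]);
// Two entries: the `second` object (name 'x') and `other` (name 'y').

const dedupedByReference = Array.from(new Set([first, other, second]));
// Three entries, because `first !== second` even though their names match.
```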
  while (frontier.length > 0) {
-   const node = frontier.pop();
-   seen.add(node.name);
-   if (!weightMap[node.name]) {
-     orderedNodes.push(node);
+   const nodeName = frontier.pop();
+   const node = nameToNode.get(nodeName)!;
+   for (const child of node.children.filter(isUsed)) {
+     if (--inCounts[child.name] === 0) {
+       orderedNodeNames.push(child.name);
+       frontier.push(child.name);
+     }
    }
-   node.children.forEach(child => {
-     if (!seen.has(child.name) && usedNodes.has(child.name) &&
-         child.inputs.every(input => seen.has(input.name))) {
-       frontier.push(child);
-     }
Do we need to verify that there are no cycles in the graph, or is that checked elsewhere? (I realize the original algorithm doesn't do that either, so I won't block the PR on this.)
It's checked in `GraphExecutor.compile` in `graph_executor.ts`, with the assumption that if a graph does not have any dynamic nodes, it is going to be a tree.
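For reference, an in-degree-counting sort like the one in this diff can also detect cycles as a by-product: if some nodes never reach an in-count of zero, they lie on or behind a cycle. The sketch below uses a hypothetical adjacency-list type (`Graph`) and function name, assumes every child also appears as a key, and does not describe what `GraphExecutor.compile` actually does.

```typescript
// Hypothetical graph representation: node name -> names of its children.
// Assumption: every child name also appears as a key.
type Graph = {[name: string]: string[]};

function topologicalOrderOrThrow(graph: Graph): string[] {
  const names = Object.keys(graph);

  // Count incoming edges for every node.
  const inCounts: {[name: string]: number} = {};
  for (const name of names) {
    inCounts[name] = 0;
  }
  for (const name of names) {
    for (const child of graph[name]) {
      inCounts[child] += 1;
    }
  }

  // Start from nodes without parents and release children as their
  // in-counts drop to zero.
  const frontier = names.filter(name => inCounts[name] === 0);
  const ordered: string[] = [];
  while (frontier.length > 0) {
    const name = frontier.pop()!;
    ordered.push(name);
    for (const child of graph[name]) {
      if (--inCounts[child] === 0) {
        frontier.push(child);
      }
    }
  }

  // Nodes on a cycle never reach in-count zero, so they are never ordered.
  if (ordered.length !== names.length) {
    throw new Error(
        `Graph has a cycle: ordered ${ordered.length} of ` +
        `${names.length} nodes.`);
  }
  return ordered;
}
```

Each node and each edge is visited a constant number of times, so the check adds no asymptotic cost to the sort itself.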
// a --> b --> c --> d
// when node `c` is predefined (e.g. given as an input tensor), we can
// skip node `a` and `b` since their outputs will never be used.
// TODO: Filter out more nodes when >=2 nodes are predefined in a path.
We should remove all nodes from the graph where any one of these conditions holds:
- The node cannot be reached from any predefined node by following children pointers. During graph execution, this node will never be executed, so it doesn't need to be in the graph.
- The node cannot be reached from any output node by following parent pointers. In this case, none of the outputs depend on this node, so it does not need to be in the graph.
All other nodes are reachable by forward traversal from a predefined node and by reverse traversal from an output node. Each of them therefore lies on a path from a predefined node to an output node and is necessary for computing the value of that output node.
We can compute this set of reachable nodes in O(V + E) by doing a DFS (caching each visited node) from the predefined nodes on the original graph and another DFS from the output nodes on the reversed graph (which we construct by swapping the direction of children pointers to create parent pointers). Then, we take the set intersection of both of these results.
It might be worth abstracting some of the graph logic away from the tfjs-specific logic to make writing and testing these graph operations easier. I'd be happy to do this.
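A minimal sketch of that two-pass reachability idea follows. The node shape (`GraphNode`) and function names are hypothetical, and the sketch assumes each node already stores its parents in `inputs`, so no separate reversed graph needs to be built.

```typescript
// Hypothetical minimal node shape; `inputs` are the parent pointers.
interface GraphNode {
  name: string;
  children: GraphNode[];
  inputs: GraphNode[];
}

// Iterative DFS that returns the names of all nodes reachable from `starts`
// by repeatedly following `next`. Each node is expanded at most once.
function reachableNames(
    starts: GraphNode[], next: (node: GraphNode) => GraphNode[]): Set<string> {
  const visited = new Set<string>();
  const stack = starts.slice();
  while (stack.length > 0) {
    const node = stack.pop()!;
    if (visited.has(node.name)) {
      continue;
    }
    visited.add(node.name);
    for (const neighbor of next(node)) {
      stack.push(neighbor);
    }
  }
  return visited;
}

// Nodes reachable forward from a predefined node AND backward from an output.
function necessaryNodeNames(
    predefined: GraphNode[], outputs: GraphNode[]): string[] {
  const forward = reachableNames(predefined, node => node.children);
  const backward = reachableNames(outputs, node => node.inputs);
  return Array.from(forward).filter(name => backward.has(name));
}
```

Both traversals touch each node and edge at most once, which gives the O(V + E) bound mentioned above; the final filter is linear in the size of the forward set.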
This function is intended to make the execution plan without changing the graph. Its upstream marks unused nodes by providing a set of names in `usedNodes` instead of modifying the graph (removing nodes from children or inputs). So I don't want to break that contract unless there is a plan to refactor everything, starting from reading the graph from a file.
At the time of making the execution plan, we don't know which nodes are output nodes. Therefore it just assumes all predefined-reachable nodes to be output nodes. The output node list is provided by the external caller in `GraphExecutor.execute` through:
- Parameter `outputs?: string[]`
- Env flag `KEEP_INTERMEDIATE_TENSORS`
Therefore, it is not something fixed with the graph. Of course we could make a different execution plan for each set of "output node names", but I can see why the original author did not do that: imagine a user debugging a weird bug by adding an intermediate name to the `outputs` param, which could produce a different plan and change the execution behavior.
Besides, I don't think abstracting out graph logic (like a topological order solver) is a good idea. It would end up with lots of redundant conversions between the tfjs graph and the abstract graph type for the solver input/output, and make it harder to extend or modify for new tfjs-graph-specific properties and requirements. For now this function is self-contained for the purpose of making the execution plan (an execution plan is not limited to topological order, so it may still be better to name this function after its purpose instead of its implementation), and no other functions need the solver. So I think it's okay to keep it like this.
But I think wrapping it in a `TFJSGraph` class as a method would definitely be better. It would be even better if we could separate `SyncGraph` and `AsyncGraph` with different checks/planner/executor. Go ahead if you have any good ideas to refactor this part!
That makes sense. I'm happy to leave the graph as-is for this PR. We can talk more offline about how best to represent the graph in the future.
{
  const nodeNameToOrder = new Map<string, number>(
      filteredOrderedNodes.map((node, order) => [node.name, order]));
  const predefinedNodeNames =
      new Set(predefinedNodes.map((node) => node.name));
  const isPredefined = (node: Node|string) =>
      predefinedNodeNames.has(typeof node === 'string' ? node : node.name);
  const willBeExecutedNodeNames =
      new Set(filteredOrderedNodes.map((node) => node.name));
  const willBeExecuted = (node: Node|string) => willBeExecutedNodeNames.has(
      typeof node === 'string' ? node : node.name);

  for (const node of filteredOrderedNodes) {
    for (const child of node.children.filter(willBeExecuted)) {
      if (!nodeNameToOrder.has(child.name)) {
        throw new Error('TopologicalSortError: Child is unreachable.');
      }
      if (nodeNameToOrder.get(node.name) > nodeNameToOrder.get(child.name)) {
        throw new Error(
            'TopologicalSortError: Node has greater order than its child.');
      }
    }
    if (!isPredefined(node)) {
      for (const input of node.inputs) {
        if (!nodeNameToOrder.has(input.name)) {
          throw new Error('TopologicalSortError: Input is unreachable.');
        }
        if (nodeNameToOrder.get(input.name) >
            nodeNameToOrder.get(node.name)) {
          throw new Error(
              'TopologicalSortError: Node has smaller order than its input.');
        }
      }
    }
  }
Should we do this validation here, or should this be integrated into a test suite instead? I guess it's better for the user to see one of these errors instead of running an incorrect graph.
Do we have an e2e test suite to test the execution results of models in our model pool?
        throw new Error('TopologicalSortError: Child is unreachable.');
      }
      if (nodeNameToOrder.get(node.name) > nodeNameToOrder.get(child.name)) {
        throw new Error(
            'TopologicalSortError: Node has greater order than its child.');
      }
    }
    if (!isPredefined(node)) {
      for (const input of node.inputs) {
        if (!nodeNameToOrder.has(input.name)) {
          throw new Error('TopologicalSortError: Input is unreachable.');
        }
        if (nodeNameToOrder.get(input.name) >
            nodeNameToOrder.get(node.name)) {
          throw new Error(
              'TopologicalSortError: Node has smaller order than its input.');
Nit: Can we log the name of each node or child we reference here to make debugging easier?
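One way this could look is sketched below. It is a simplified, hypothetical version of the parent/child check only (no `willBeExecuted` filter, no input check), and the message wording is an assumption rather than what the PR ended up using.

```typescript
// Hypothetical minimal node shape for the sketch.
interface OrderedNode {
  name: string;
  children: Array<{name: string}>;
}

// Verifies that every node appears before each of its children and names the
// offending nodes in the error message.
function assertParentsBeforeChildren(orderedNodes: OrderedNode[]): void {
  const nodeNameToOrder = new Map<string, number>();
  orderedNodes.forEach((node, order) => {
    nodeNameToOrder.set(node.name, order);
  });

  for (const node of orderedNodes) {
    const nodeOrder = nodeNameToOrder.get(node.name)!;
    for (const child of node.children) {
      const childOrder = nodeNameToOrder.get(child.name);
      if (childOrder === undefined) {
        throw new Error(
            `TopologicalSortError: Child ${child.name} of node ` +
            `${node.name} is unreachable.`);
      }
      if (nodeOrder > childOrder) {
        throw new Error(
            `TopologicalSortError: Node ${node.name} (order ${nodeOrder}) ` +
            `has greater order than its child ${child.name} ` +
            `(order ${childOrder}).`);
      }
    }
  }
}
```

Including both names and both order indices means a failure report identifies the exact edge that was violated without needing a debugger.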
@pyu10055 Please take a look when you get a chance. Thanks!
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @chunnienc and @mattsoulanille)
tfjs-converter/src/executor/model_analysis.ts line 103 at r3 (raw file):
 * need to be executed to compute the output.
 */
export function getNodesInTopologicalOrder(
Nit: This method is relatively long; it would be great if you could abstract the key steps into separate functions.
@mattsoulanille Could you review the helper functions and doc comments again? Thanks.
Co-authored-by: Matthew Soulanille <matthew@soulanille.net>
This PR rewrites the algorithm to determine execution order of sync graph nodes:
- `OrderParent < OrderChild` if `parent -> child` always holds. The old implementation does not guarantee this when `child` is predefined (given with an input tensor, is a weight, etc.), which leads to errors in the execution-order-based tensor disposal plan.
- Complexity goes from `O(NE)` to `O(E)`.
Verified with sample models (local benchmark) and diffuser.