-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform][Bugfix] Handle non-composite lambda functions in FuseOps #16598
[Transform][Bugfix] Handle non-composite lambda functions in FuseOps #16598
Conversation
Prior to this commit, calling `FuseOpsByPattern` with `annotate_codegen=True` would cause an error when encountering a lambda function. This was caused by the `CompositeFunctionAnnotator` asserting that all `relax::Function` encountered must have the `kComposite` attribute. While this is true for all lambda functions produced by `FuseOpsByPattern`, the user may have defined other lambda functions as well. This commit updates `CompositeFunctionAnnotator` to ignore lambda functions that do not have a `kComposite` attribute.
@@ -1238,10 +1238,14 @@ class CompositeFunctionAnnotator : public ExprMutator { | |||
|
|||
Expr VisitExpr_(const FunctionNode* func_node) final { | |||
Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node)); | |||
auto composite_name = func_node->GetAttr<String>(attr::kComposite); | |||
|
|||
if (!func_node->GetAttr<String>(attr::kComposite)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are non-composite functions visited?
tvm/src/relax/transform/fuse_ops.cc
Line 1224 in 76c1708
auto new_func = Downcast<Function>(VisitExpr(func)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PatternBasedPartitioner
only visits non-composite functions, produces a composite function for each pattern match, and updates the non-composite function to call the newly-generated composite function. Afterwards, the call to CompositeFunctionAnnotator
is called. This visits only non-composite functions, finds any relax-to-relax function calls, and asserts that the callee is composite.
The callee will be composite for every function call generated by PatternBasedPartitioner
, but that doesn't guarantee that all relax-to-relax function calls have a composite callee. If the IRModule
contains a relax-to-relax call prior to PatternBasedPartitioner
, that callee may be non-composite. This IRModule would be entirely legal, but would trigger the assert in CompositeFunctionAnnotator
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, the problem isn't with calls to inner functions as on line 1224, but with calls to other functions within the IRModule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the problem is if the callee is not a global var, the callee function will still be visited, so the fix makes sense to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops, you're right on that one. It's if there is a inner function in the input IRModule
. (Apologies, trying to track too many PRs at one time.)
Prior to this commit, calling
FuseOpsByPattern
withannotate_codegen=True
would cause an error when encountering a lambda function. This was caused by theCompositeFunctionAnnotator
asserting that allrelax::Function
encountered must have thekComposite
attribute. While this is true for all lambda functions produced byFuseOpsByPattern
, the user may have defined other lambda functions as well.This commit updates
CompositeFunctionAnnotator
to ignore lambda functions that do not have akComposite
attribute.