Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir:python] Small optimization to get_op_result_or_results. #123866

Merged
merged 1 commit into from
Jan 22, 2025

Conversation

hawkinsp
Copy link
Contributor

  • We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway.
  • If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.

* We can call .results without figuring out whether we have an Operation
  or an OpView, and that's likely the common case anyway.
* If we have one or more results, we can return them directly, with no
  need for a call to get_op_result_or_value. We're guaranteed that
  .results returns a PyOpResultList, so we have either an OpResult or
  sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.
@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

Changes
  • We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway.
  • If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects.

This saves a few 100ms during IR construction in an LLM JAX benchmark.


Full diff: /~https://github.com/llvm/llvm-project/pull/123866.diff

1 Files Affected:

  • (modified) mlir/python/mlir/dialects/_ods_common.py (+10-9)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..5b67ab03d6f494 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -133,15 +133,16 @@ def get_op_results_or_values(
 def get_op_result_or_op_results(
     op: _Union[_cext.ir.OpView, _cext.ir.Operation],
 ) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
-    if isinstance(op, _cext.ir.OpView):
-        op = op.operation
-    return (
-        list(get_op_results_or_values(op))
-        if len(op.results) > 1
-        else get_op_result_or_value(op)
-        if len(op.results) > 0
-        else op
-    )
+    results = op.results
+    num_results = len(results)
+    if num_results == 1:
+        return results[0]
+    elif num_results > 1:
+        return results
+    elif isinstance(op, _cext.ir.OpView):
+        return op.operation
+    else:
+        return op
 
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks

@jpienaar jpienaar merged commit ff0f1dd into llvm:main Jan 22, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants