From b78a36337094541c00ac893e8133b0cde4b8528e Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Fri, 29 Nov 2024 08:01:16 +0100 Subject: [PATCH 1/5] Expose BlockBuilder's Analyzer in Python --- include/tvm/arith/analyzer.h | 7 ++++++ python/tvm/arith/analyzer.py | 3 +-- python/tvm/relax/block_builder.py | 5 +++++ src/arith/analyzer.cc | 23 +++++++++++++++++--- src/relax/ir/block_builder.cc | 6 +++++ tests/python/relax/test_blockbuilder_core.py | 9 ++++++++ 6 files changed, 48 insertions(+), 5 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 044e5d6f6ca9..7e5dfe4e3742 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -779,6 +779,13 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); + + /*! + * \brief Returns the instance wrapped in a closure, suitable for FFI purposes. + * + * \return A function that exposes the methods of the underlying Analyzer instance. + */ + runtime::TypedPackedFunc AsFunc(); }; } // namespace arith diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index f8069a717da3..6552fdd445c4 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -105,8 +105,7 @@ class Analyzer: be used to perform various symbolic integer analysis. """ - def __init__(self): - _mod = _ffi_api.CreateAnalyzer() + def __init__(self, _mod=_ffi_api.CreateAnalyzer()): self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") self._bind = _mod("bind") diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 37866840bd68..db2a78276bc2 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -22,6 +22,7 @@ import tvm from tvm import relax as rx from tvm import tir +from tvm.arith.analyzer import Analyzer from tvm.ir.module import IRModule from tvm.runtime import Object @@ -163,6 +164,7 @@ def __init__(self, mod: IRModule = None): # Which functions are currently being defined self._func_stack: List[FunctionScope] = [] self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore + self._analyzer = Analyzer(_ffi_api.BlockBuilderGetAnalyzer(self)) def _begin_dataflow_block(self) -> None: _ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore @@ -797,3 +799,6 @@ def end_scope(self) -> None: """End the current scope. Please see `begin_scope` for details""" return _ffi_api.BlockBuilderEndScope(self) # type: ignore + + def get_analyzer(self) -> Analyzer: + return self._analyzer diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 08d5e9379dc6..d6789fc8bf12 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -268,10 +268,11 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { +namespace { +template +TypedPackedFunc GetAnalyzerFunc(Ptr self) { using runtime::PackedFunc; using runtime::TypedPackedFunc; - auto self = std::make_shared(); auto f = [self](std::string name) -> PackedFunc { if (name == "const_int_bound") { return PackedFunc( @@ -347,7 +348,23 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu } return PackedFunc(); }; - *ret = TypedPackedFunc(f); + return TypedPackedFunc(f); +} + +// Duck typed smart pointer interface, allowing us to pass a non-smart ptr to GetAnalyzerFunc +template +struct Ptr { + Ptr(T* ptr) : ptr(ptr) {} + T* get() const { return ptr; } + T* operator->() const { return this->ptr; } + T* ptr; +}; +} // namespace + +TypedPackedFunc Analyzer::AsFunc() { return GetAnalyzerFunc(Ptr(this)); } + +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = GetAnalyzerFunc(std::make_shared()); }); } // namespace arith diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index b8092bbf3a4d..badf37443a37 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -39,6 +39,7 @@ #include #include "../../node/ndarray_hash_equal.h" +#include "tvm/runtime/packed_func.h" // Block builder have three categories of logics that are interdependent with each other. // @@ -1097,6 +1098,11 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") /*add_underscore*/ false); }); +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { + BlockBuilder block_builder = args[0]; + *ret = block_builder->GetAnalyzer()->AsFunc(); +}); + TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") .set_body_method(&BlockBuilderNode::AddFunction); diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index 02cf7f14c155..b7da5e4890c5 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -990,5 +990,14 @@ def subroutine_c(arg: R.Tensor) -> R.Tensor: bb.update_func(gvar_b, subroutine_c) +def test_analyzer_ref(): + bb = rx.BlockBuilder() + analyzer = bb.get_analyzer() + x, y = te.var("x"), te.var("y") + m = analyzer.modular_set(x * 6 + y * 4) + assert m.coeff == 2 + assert m.base == 0 + + if __name__ == "__main__": tvm.testing.main() From 2575dee86d434320557086589461243257782375 Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Fri, 29 Nov 2024 11:15:57 +0100 Subject: [PATCH 2/5] Make BlockBuilder's Analyzer member a shared_ptr --- include/tvm/arith/analyzer.h | 8 ++------ include/tvm/relax/block_builder.h | 6 ++++++ src/arith/analyzer.cc | 18 ++---------------- src/relax/ir/block_builder.cc | 16 ++++++++++------ 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 7e5dfe4e3742..6a80a65d9184 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -780,12 +780,8 @@ class TVM_DLL Analyzer { */ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); - /*! - * \brief Returns the instance wrapped in a closure, suitable for FFI purposes. - * - * \return A function that exposes the methods of the underlying Analyzer instance. - */ - runtime::TypedPackedFunc AsFunc(); + static runtime::TypedPackedFunc AsFunc( + std::shared_ptr analyzer); }; } // namespace arith diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index ad2b9820707a..119b00edce4c 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -256,6 +256,12 @@ class BlockBuilderNode : public Object { */ virtual arith::Analyzer* GetAnalyzer() = 0; + /*! + * \brief Returns the analyzer wrapped in a closure, suitable for FFI purposes. + * \return A function that exposes the methods of the underlying Analyzer instance. + */ + virtual runtime::TypedPackedFunc GetAnalyzerAsFunc() = 0; + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.BlockBuilder"; TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index d6789fc8bf12..dff07114206d 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -268,9 +268,7 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -namespace { -template -TypedPackedFunc GetAnalyzerFunc(Ptr self) { +TypedPackedFunc Analyzer::AsFunc(std::shared_ptr self) { using runtime::PackedFunc; using runtime::TypedPackedFunc; auto f = [self](std::string name) -> PackedFunc { @@ -351,20 +349,8 @@ TypedPackedFunc GetAnalyzerFunc(Ptr self) { return TypedPackedFunc(f); } -// Duck typed smart pointer interface, allowing us to pass a non-smart ptr to GetAnalyzerFunc -template -struct Ptr { - Ptr(T* ptr) : ptr(ptr) {} - T* get() const { return ptr; } - T* operator->() const { return this->ptr; } - T* ptr; -}; -} // namespace - -TypedPackedFunc Analyzer::AsFunc() { return GetAnalyzerFunc(Ptr(this)); } - TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = GetAnalyzerFunc(std::make_shared()); + *ret = Analyzer::AsFunc(std::make_shared()); }); } // namespace arith diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index badf37443a37..c6874667c330 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -218,11 +218,11 @@ class BlockBuilderImpl : public BlockBuilderNode { // of shape inference. In many cases, knowning that the // shape variable is non-negative allows for simpler // expressions for dynamic shapes. - analyzer_.MarkGlobalNonNegValue(shape_var); + analyzer_->MarkGlobalNonNegValue(shape_var); } else { const PrimExpr& old_shape_expr = (*it).second; CHECK(old_shape_expr.same_as(shape_expr) || - analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + analyzer_->CanProveEqual(old_shape_expr, shape_expr)) << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " << shape_expr; } @@ -305,7 +305,11 @@ class BlockBuilderImpl : public BlockBuilderNode { } } - arith::Analyzer* GetAnalyzer() final { return &analyzer_; } + arith::Analyzer* GetAnalyzer() final { return analyzer_.get(); } + + runtime::TypedPackedFunc GetAnalyzerAsFunc() final { + return arith::Analyzer::AsFunc(analyzer_); + } protected: /*! @@ -362,7 +366,7 @@ class BlockBuilderImpl : public BlockBuilderNode { IRModule context_mod_; /*! \brief Internal analzyer */ - arith::Analyzer analyzer_; + std::shared_ptr analyzer_ = std::make_shared(); /*! * \return The current frame. @@ -853,7 +857,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(call->op); ICHECK(opt) << "Call->op must contains a function struct info"; FuncStructInfo finfo = opt.value(); - return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); + return DeriveCallRetStructInfo(finfo, call, GetRef(this), analyzer_.get()); } } @@ -1100,7 +1104,7 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") TVM_REGISTER_GLOBAL("relax.BlockBuilderGetAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { BlockBuilder block_builder = args[0]; - *ret = block_builder->GetAnalyzer()->AsFunc(); + *ret = block_builder->GetAnalyzerAsFunc(); }); TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") From 30c651f39ba9c7de8c28e2ba52536befcfd14262 Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Tue, 3 Dec 2024 08:30:17 +0100 Subject: [PATCH 3/5] Fix Analyzer instance creation issue --- python/tvm/arith/analyzer.py | 3 ++- python/tvm/relax/block_builder.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 6552fdd445c4..295df5eef0a3 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -105,7 +105,8 @@ class Analyzer: be used to perform various symbolic integer analysis. """ - def __init__(self, _mod=_ffi_api.CreateAnalyzer()): + def __init__(self, _mod_ctor=_ffi_api.CreateAnalyzer): + _mod = _mod_ctor() self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update") self._bind = _mod("bind") diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index db2a78276bc2..608d4889034c 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -17,6 +17,7 @@ # pylint: disable=no-else-return, invalid-name, unused-argument, import-outside-toplevel """Developer API of constructing Relax AST.""" +from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Union import tvm @@ -164,7 +165,7 @@ def __init__(self, mod: IRModule = None): # Which functions are currently being defined self._func_stack: List[FunctionScope] = [] self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore - self._analyzer = Analyzer(_ffi_api.BlockBuilderGetAnalyzer(self)) + self._analyzer = Analyzer(partial(_ffi_api.BlockBuilderGetAnalyzer, self)) def _begin_dataflow_block(self) -> None: _ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore From 4957f44b2c1a6ef67143fc76f8de8fd6c0501b2a Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Tue, 3 Dec 2024 08:34:33 +0100 Subject: [PATCH 4/5] Fix linter errors --- include/tvm/arith/analyzer.h | 1 + include/tvm/relax/block_builder.h | 2 ++ 2 files changed, 3 insertions(+) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6a80a65d9184..b1111c1c17c3 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -30,6 +30,7 @@ #include #include +#include #include #include diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 119b00edce4c..2a17ad970026 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -30,6 +30,8 @@ #include #include +#include + namespace tvm { namespace relax { From f3f15169618c35227514e1dc0f1e212ad61ed421 Mon Sep 17 00:00:00 2001 From: Oskar Gustafsson Date: Thu, 5 Dec 2024 08:23:56 +0100 Subject: [PATCH 5/5] Fix DSO module loading issue --- python/tvm/arith/analyzer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 295df5eef0a3..9be3acc4b0c1 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -105,7 +105,9 @@ class Analyzer: be used to perform various symbolic integer analysis. """ - def __init__(self, _mod_ctor=_ffi_api.CreateAnalyzer): + def __init__(self, _mod_ctor=None): + if _mod_ctor is None: + _mod_ctor = _ffi_api.CreateAnalyzer _mod = _mod_ctor() self._const_int_bound = _mod("const_int_bound") self._const_int_bound_update = _mod("const_int_bound_update")