Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Expose BlockBuilder's Analyzer instance in Python #17548

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -779,6 +780,9 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);

static runtime::TypedPackedFunc<PackedFunc(std::string)> AsFunc(
std::shared_ptr<Analyzer> analyzer);
};

} // namespace arith
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/relax/utils.h>
#include <tvm/runtime/object.h>

#include <string>

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -256,6 +258,12 @@ class BlockBuilderNode : public Object {
*/
virtual arith::Analyzer* GetAnalyzer() = 0;

/*!
* \brief Returns the analyzer wrapped in a closure, suitable for FFI purposes.
* \return A function that exposes the methods of the underlying Analyzer instance.
*/
virtual runtime::TypedPackedFunc<PackedFunc(std::string)> GetAnalyzerAsFunc() = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.BlockBuilder";
TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object);
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ class Analyzer:
be used to perform various symbolic integer analysis.
"""

def __init__(self):
_mod = _ffi_api.CreateAnalyzer()
def __init__(self, _mod_ctor=None):
if _mod_ctor is None:
_mod_ctor = _ffi_api.CreateAnalyzer
_mod = _mod_ctor()
self._const_int_bound = _mod("const_int_bound")
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
# pylint: disable=no-else-return, invalid-name, unused-argument, import-outside-toplevel
"""Developer API of constructing Relax AST."""

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import tvm
from tvm import relax as rx
from tvm import tir
from tvm.arith.analyzer import Analyzer
from tvm.ir.module import IRModule
from tvm.runtime import Object

Expand Down Expand Up @@ -163,6 +165,7 @@ def __init__(self, mod: IRModule = None):
# Which functions are currently being defined
self._func_stack: List[FunctionScope] = []
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore
self._analyzer = Analyzer(partial(_ffi_api.BlockBuilderGetAnalyzer, self))

def _begin_dataflow_block(self) -> None:
_ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore
Expand Down Expand Up @@ -797,3 +800,6 @@ def end_scope(self) -> None:
"""End the current scope. Please see `begin_scope` for details"""

return _ffi_api.BlockBuilderEndScope(self) # type: ignore

def get_analyzer(self) -> Analyzer:
return self._analyzer
9 changes: 6 additions & 3 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,9 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
return res;
}

TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
TypedPackedFunc<PackedFunc(std::string)> Analyzer::AsFunc(std::shared_ptr<Analyzer> self) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto self = std::make_shared<Analyzer>();
auto f = [self](std::string name) -> PackedFunc {
if (name == "const_int_bound") {
return PackedFunc(
Expand Down Expand Up @@ -347,7 +346,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
}
return PackedFunc();
};
*ret = TypedPackedFunc<PackedFunc(std::string)>(f);
return TypedPackedFunc<PackedFunc(std::string)>(f);
}

TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Analyzer::AsFunc(std::make_shared<Analyzer>());
});

} // namespace arith
Expand Down
20 changes: 15 additions & 5 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <vector>

#include "../../node/ndarray_hash_equal.h"
#include "tvm/runtime/packed_func.h"

// Block builder have three categories of logics that are interdependent with each other.
//
Expand Down Expand Up @@ -217,11 +218,11 @@ class BlockBuilderImpl : public BlockBuilderNode {
// of shape inference. In many cases, knowning that the
// shape variable is non-negative allows for simpler
// expressions for dynamic shapes.
analyzer_.MarkGlobalNonNegValue(shape_var);
analyzer_->MarkGlobalNonNegValue(shape_var);
} else {
const PrimExpr& old_shape_expr = (*it).second;
CHECK(old_shape_expr.same_as(shape_expr) ||
analyzer_.CanProveEqual(old_shape_expr, shape_expr))
analyzer_->CanProveEqual(old_shape_expr, shape_expr))
<< "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs "
<< shape_expr;
}
Expand Down Expand Up @@ -304,7 +305,11 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
}

arith::Analyzer* GetAnalyzer() final { return &analyzer_; }
arith::Analyzer* GetAnalyzer() final { return analyzer_.get(); }

runtime::TypedPackedFunc<PackedFunc(std::string)> GetAnalyzerAsFunc() final {
return arith::Analyzer::AsFunc(analyzer_);
}

protected:
/*!
Expand Down Expand Up @@ -361,7 +366,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
IRModule context_mod_;

/*! \brief Internal analzyer */
arith::Analyzer analyzer_;
std::shared_ptr<arith::Analyzer> analyzer_ = std::make_shared<arith::Analyzer>();

/*!
* \return The current frame.
Expand Down Expand Up @@ -852,7 +857,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
auto opt = MatchStructInfo<FuncStructInfo>(call->op);
ICHECK(opt) << "Call->op must contains a function struct info";
FuncStructInfo finfo = opt.value();
return DeriveCallRetStructInfo(finfo, call, GetRef<BlockBuilder>(this), &analyzer_);
return DeriveCallRetStructInfo(finfo, call, GetRef<BlockBuilder>(this), analyzer_.get());
}
}

Expand Down Expand Up @@ -1097,6 +1102,11 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName")
/*add_underscore*/ false);
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderGetAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
BlockBuilder block_builder = args[0];
*ret = block_builder->GetAnalyzerAsFunc();
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction")
.set_body_method<BlockBuilder>(&BlockBuilderNode::AddFunction);

Expand Down
9 changes: 9 additions & 0 deletions tests/python/relax/test_blockbuilder_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,5 +990,14 @@ def subroutine_c(arg: R.Tensor) -> R.Tensor:
bb.update_func(gvar_b, subroutine_c)


def test_analyzer_ref():
bb = rx.BlockBuilder()
analyzer = bb.get_analyzer()
x, y = te.var("x"), te.var("y")
m = analyzer.modular_set(x * 6 + y * 4)
assert m.coeff == 2
assert m.base == 0


if __name__ == "__main__":
tvm.testing.main()
Loading