diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 5710de5bf73f6..cec7c4d141cbe 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -22,7 +22,8 @@ from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem, Let +from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.ir import Op from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1033,10 +1034,9 @@ def visit_tuple_getitem(self, op): visit = super().visit_tuple_getitem(op) if visit.index != 0: return visit - if isinstance(visit.tuple_value, Call) and isinstance(visit.tuple_value.op, Let): - return visit if ( isinstance(visit.tuple_value, Call) + and isinstance(visit.tuple_value.op, Op) and visit.tuple_value.op.name == "nn.dropout" and visit.index == 0 ):