diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 762aa58f7298..5662ef5b45a6 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -75,6 +75,33 @@ class SimplifyReshape : public DFPatternRewrite { DFPattern x_; }; +/*! + * \brief SimplifyCast matches the pattern of cast data to the same dtype. + */ +class SimplifyCast : public DFPatternRewrite { + public: + SimplifyCast() { + data_pat_ = IsWildcard(); + like_pat_ = IsWildcard(); + pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call = pre.as(); + const TensorTypeNode* data_ty = call->args[0]->checked_type().as(); + const TensorTypeNode* like_ty = pre->checked_type().as(); + if (like_ty->dtype == data_ty->dtype) { + return node_map[data_pat_][0]; + } + return post; + } + + protected: + DFPattern data_pat_; + DFPattern like_pat_; +}; + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. @@ -321,6 +348,17 @@ class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { } }; +class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + // `like_pat_` here is `fill_value` + return MakeFull(node_map[like_pat_][0], shape, dtype); + } +}; + class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { public: ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {} @@ -439,12 +477,14 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { DFPatternRewriteComposer composer; composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); } diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index d015cdd36c2d..d1dffa34578b 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -236,6 +236,20 @@ def check(x, y=None, do_nothing=False): check(id_op(const, x), id_op(op_like(x), x)) +def test_simplify_cast(): + dtype = "int32" + data = relay.var("data", shape=(3, 4, 5), dtype=dtype) + expr1 = relay.cast(data, dtype) + dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype) + expr2 = relay.cast_like(data, dtype_like) + + expected = run_infer_type(data) + actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual1, expected) + actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual2, expected) + + def test_concretize_reshape_like(): data = relay.var("data", shape=(2, 3, 4), dtype="float32") shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") @@ -276,6 +290,17 @@ def test_concretize_ones_like(): assert tvm.ir.structural_equal(actual, expected) +def test_concretize_full_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + fill_value = relay.var("fill", relay.TensorType((), "float32")) + expr = relay.full_like(shape_like, fill_value) + + expected = run_infer_type(relay.full(fill_value, (3, 4, 5), dtype)) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_collapse_sum_like(): data = relay.var("data", shape=(3, 3, 3), dtype="float32") shape_like = relay.var("shape_like", shape=(3,), dtype="float32")