Skip to content

Commit

Permalink
[Relay][Pass] SimplifyCastLike/Cast and ConcretizeFullLikeRewrite rew…
Browse files Browse the repository at this point in the history
…rites for SimplifyExpr (apache#7827)
  • Loading branch information
hgt312 authored and Trevor Morris committed May 6, 2021
1 parent 7a653ea commit 0850c4c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,33 @@ class SimplifyReshape : public DFPatternRewrite {
DFPattern x_;
};

/*!
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
*/
class SimplifyCast : public DFPatternRewrite {
public:
SimplifyCast() {
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
const TensorTypeNode* data_ty = call->args[0]->checked_type().as<TensorTypeNode>();
const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
if (like_ty->dtype == data_ty->dtype) {
return node_map[data_pat_][0];
}
return post;
}

protected:
DFPattern data_pat_;
DFPattern like_pat_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -321,6 +348,17 @@ class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
}
};

class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite {
public:
ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {}

Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
DataType dtype) const override {
// `like_pat_` here is `fill_value`
return MakeFull(node_map[like_pat_][0], shape, dtype);
}
};

class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
public:
ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
Expand Down Expand Up @@ -439,12 +477,14 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
DFPatternRewriteComposer composer;
composer.AddRewrite<ConcretizeZerosLikeRewrite>();
composer.AddRewrite<ConcretizeOnesLikeRewrite>();
composer.AddRewrite<ConcretizeFullLikeRewrite>();
composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
composer.AddRewrite<SimplifyCast>();
composer.AddRewrite<FullElementwise>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ def check(x, y=None, do_nothing=False):
check(id_op(const, x), id_op(op_like(x), x))


def test_simplify_cast():
dtype = "int32"
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
expr1 = relay.cast(data, dtype)
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
expr2 = relay.cast_like(data, dtype_like)

expected = run_infer_type(data)
actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr())
assert tvm.ir.structural_equal(actual1, expected)
actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
assert tvm.ir.structural_equal(actual2, expected)


def test_concretize_reshape_like():
data = relay.var("data", shape=(2, 3, 4), dtype="float32")
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
Expand Down Expand Up @@ -276,6 +290,17 @@ def test_concretize_ones_like():
assert tvm.ir.structural_equal(actual, expected)


def test_concretize_full_like():
dtype = "int32"
shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
fill_value = relay.var("fill", relay.TensorType((), "float32"))
expr = relay.full_like(shape_like, fill_value)

expected = run_infer_type(relay.full(fill_value, (3, 4, 5), dtype))
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
assert tvm.ir.structural_equal(actual, expected)


def test_concretize_collapse_sum_like():
data = relay.var("data", shape=(3, 3, 3), dtype="float32")
shape_like = relay.var("shape_like", shape=(3,), dtype="float32")
Expand Down

0 comments on commit 0850c4c

Please sign in to comment.