diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 6d07058c7b4a07..9b805cb891a569 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -29,29 +29,29 @@ struct ReductionPattern { SingleReductionOpPattern reduction_op_pattern; }; -// Stmt := IS | R | PS -// ops in StmtPattern will be lowered into a inlined cuda code. -template -using StmtPattern = std::variant, ReductionPattern, PartialShardablePattern>; - -// Stmts := [Stmt] -template -using StmtsPattern = std::list; - -// fuse rules: -// 1. IS * IS -> IS -// 2. PS * PS -> PS -// 3. IS * PS -> PS -// 4. IS * R -> R -// 5. PS * R -> R - -// lifting rules: -// 1. R -> Stmts -// 2. PS -> Stmts -// 3. Stmts * Stmts -> Stmts - -// OpTopoPattern := Error | Stmts -template -using OpTopoPattern = std::variant, StmtsPattern>; +// // Stmt := IS | R | PS +// // ops in StmtPattern will be lowered into a inlined cuda code. +// template +// using StmtPattern = std::variant, ReductionPattern, PartialShardablePattern>; + +// // Stmts := [Stmt] +// template +// using StmtsPattern = std::list>; + +// // fuse rules: +// // 1. IS * IS -> IS +// // 2. PS * PS -> PS +// // 3. IS * PS -> PS +// // 4. IS * R -> R +// // 5. PS * R -> R + +// // lifting rules: +// // 1. R -> Stmts +// // 2. PS -> Stmts +// // 3. Stmts * Stmts -> Stmts + +// // OpTopoPattern := Error | Stmts +// template +// using OpTopoPattern = std::variant, StmtsPattern>; } diff --git a/paddle/cinn/frontend/CMakeLists.txt b/paddle/cinn/frontend/CMakeLists.txt index e04ae9e9851c0a..3360b9620edb53 100755 --- a/paddle/cinn/frontend/CMakeLists.txt +++ b/paddle/cinn/frontend/CMakeLists.txt @@ -10,7 +10,8 @@ gather_srcs( op_mapper_registry.cc paddle_model_convertor.cc program_pass.cc - optimize.cc) + optimize.cc + group_pattern_util.cc) if(NOT WITH_CUDA) cinn_cc_test( diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index cb7e52f1bc8cd6..5fcfebc3df68cd 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -3,40 +3,38 @@ #include #include #include +#include +#include #include "paddle/cinn/api/op_topo_pattern.h" #include "paddle/pir/include/core/operation.h" +#include "glog/logging.h" -namespace cinn::frontend { +namespace cinn::api { struct FrontendPattern {}; -} - -namespace cinn::api { - template<> -struct ErrorPattern { - explicit ErrorPattern(const ErrorPattern& other) = default; +struct ErrorPattern { + explicit ErrorPattern(const ErrorPattern& other) = default; std::vector ops; std::string error_string; }; template<> -struct InjectiveSourcePattern { - explicit InjectiveSourcePattern(const InjectiveSourcePattern& other) = default; +struct InjectiveSourcePattern { + explicit InjectiveSourcePattern(const InjectiveSourcePattern& other) = default; std::vector ops; }; template<> -struct SingleReductionOpPattern { - explicit SingleReductionOpPattern(const SingleReductionOpPattern& other) = default; +struct SingleReductionOpPattern { + explicit SingleReductionOpPattern(const SingleReductionOpPattern& other) = default; const pir::Operation* reduce_op; }; - struct ShardableAxis { int axis; - std::optional axis_name; + std::string axis_name; bool operator==(const ShardableAxis& other) const { return this->axis == other.axis && this->axis_name == other.axis_name; @@ -51,7 +49,7 @@ struct ShardableAxis { using ShardableAxes = std::vector; struct ShardableAxesUtil { - using OldName2NewName = std::unorderd_map; + using OldName2NewName = std::unordered_map; static OldName2NewName GetOldName2NewName(const ShardableAxes& old_sa, const ShardableAxes& new_sa) { OldName2NewName old_name2new_name; @@ -69,7 +67,7 @@ struct ShardableAxesUtil { for (auto iter = sa->begin(); iter != sa->end();) { const auto& pair_it = old2new.find(iter->axis_name); if (pair_it != old2new.end()) { - iter->axis_name = pair_it.second; + iter->axis_name = pair_it->second; ++iter; } else { iter = sa->erase(iter); @@ -109,8 +107,8 @@ struct ShardableAxesSignature { }; template<> -struct PartialShardablePattern { - explicit PartialShardablePattern(const PartialShardablePattern& other) = default; +struct PartialShardablePattern { + explicit PartialShardablePattern(const PartialShardablePattern& other) = default; std::vector ops; ShardableAxesSignature shardable_axes_signature; @@ -119,8 +117,12 @@ struct PartialShardablePattern { } namespace cinn::frontend { +using IS = api::InjectiveSourcePattern; +using R = api::ReductionPattern; +using PS = api::PartialShardablePattern; -using GroupPattern = api::OpTopoPattern; -using ErrorGroupPattern = api::ErrorPattern; +using StmtPattern = std::variant; +using ErrorGroupPattern = api::ErrorPattern; +using GroupPattern = std::variant; } \ No newline at end of file diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index ae3cb963280441..8f560c3342e48a 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -5,36 +5,30 @@ #include #include #include +#include namespace cinn::frontend { namespace { - -using IS = api::InjectiveSourcePattern; -using R = api::ReductionPattern; -using PS = api::PartialShardablePattern; -using StmtPattern = api::StmtPattern; using OpPatternKind = cinn::hlir::framework::OpPatternKind; +using StmtIter = std::list::iterator; +using OpVisitor = std::function; +using NodeVisitor = std::function; + + OpPatternKind GetOpPatternKind(const ::pir::Operation* node) { return hlir::framework::pir::CompatibleInfo::OpKind(*node); } -std::function MakeGetterOrderValue4Op(const cinn::dialect::FusionOp& fusion_op) { - std::unordered_map op2order_in_block; - size_t order = 0; - for (const pir::Operation* op : fusion_op.block()->ops()) { - op2order_in_block[op] = ++order; - } - return [map=std::move(op2order_in_block)](const pir::Operation* op) { - const auto& iter = map.find(op); - CHECK(iter != map.end()); - return iter->second; - }; +bool IsGeneralInjective(const pir::Operation* op) { + hlir::framework::OpPatternKind op_pattern_kind = GetOpPatternKind(op); + return op_pattern_kind == hlir::framework::kElementWise + || op_pattern_kind == hlir::framework::kBroadcast + || op_pattern_kind == hlir::framework::kInjective; } - -bool IsISPattern(const StmtPattern& pattern){ +bool IsISPattern(StmtPattern& pattern){ return std::holds_alternative(pattern); } @@ -46,6 +40,47 @@ bool IsRPattern(const StmtPattern& pattern){ return std::holds_alternative(pattern); } +void VisitInputOp(const pir::Operation* op, const OpVisitor& DoEach) { + for (int i = 0; i < op->num_operands(); ++i) { + const auto* input_op = op->operand_source(i).defining_op(); + DoEach(input_op); + } +} + +void VisitOutputOp(const pir::Operation* op, const OpVisitor& DoEach) { + for (int i = 0; i < op->num_results(); ++i) { + pir::Value output = op->result(i); + for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) { + const auto* consumer_op = consumer_it->owner(); + DoEach(consumer_op); + } + } +} + +template +void VisitStmtOpImpl(const IS& injective_source, const DoEachT& DoEach) { + for (const auto* op : injective_source.ops) { + DoEach(op); + } +} + +template +void VisitStmtOpImpl(const R& reduce, const DoEachT& DoEach) { + DoEach(reduce.reduce_op); +} + +template +void VisitStmtOpImpl(const PS& partial_shardable, const DoEachT& DoEach) { + for (const auto* op : partial_shardable.ops) { + DoEach(op); + } +} + +template +void VisitStmtOp(const StmtPattern& stmt, const DoEachT& DoEach) { + std::visit([&](const auto& impl) { VisitStmtOpImpl(impl, DoEach); }, stmt); +} + std::function MakePredicatorIsInThisFusionOp(const cinn::dialect::FusionOp& fusion_op) { std::set set; for (const pir::Operation* op : fusion_op.block()->ops()) { @@ -58,47 +93,26 @@ std::function MakePredicatorIsInThisFusionOp(const }; } -bool IsGeneralInjective(const pir::Operation* op) { - hlir::framework::OpPatternKind op_pattern_kind = GetOpPatternKind(op); - return op_pattern_kind == hlir::framework::kElementWise - || op_pattern_kind == hlir::framework::kBroadcast - || op_pattern_kind == hlir::framework::kInjective; -} - std::function MakePredicatorIsInjectiveSource( const cinn::dialect::FusionOp& fusion_op, const std::function& IsInThisFusionOp) { - using NodeVisitor = std::function; - const auto VisitEachInput = [&](const pir::Operation* op, const NodeVisitor& DoEach) { - for (int i = 0; i < op->num_operands(); ++i) { - const auto* input_op = op->operand_source(i).defining_op(); - if (IsInThisFusionOp(input_op)) { - DoEach(input_op); - } - } - }; - const auto VisitEachOutput = [&](const pir::Operation* op, const NodeVisitor& DoEach) { - for (int i = 0; i < op->num_results(); ++i) { - pir::Value output = op->result(i); - for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) { - const auto* consumer_op = consumer_it->owner(); - if (IsInThisFusionOp(consumer_op)) { - DoEach(consumer_op); + + const auto& IsSource = [&](const pir::Operation* op) { + std::size_t num_inputs = 0; + VisitInputOp(op, + [&](const pir::Operation* input) { + if(IsInThisFusionOp(input)){ + ++num_inputs; } } - } + ); + return num_inputs == 0; }; const auto starts = [&]{ - const auto& IsSource = [&](const pir::Operation* op) { - std::size_t num_inputs = 0; - VisitEachInput([&](const pir::Operation*) { ++num_inputs}); - return num_inputs == 0; - }; std::list starts; for (const auto* op : fusion_op.GetOperators()) { - if (!IsInThisFusionOp(op)) continue; - if (IsSource(op)) { + if (!IsInThisFusionOp(op) && IsSource(op)) { starts.push_back(op); } else { // do nothing. @@ -111,9 +125,13 @@ std::function MakePredicatorIsInjectiveSource( auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) { bool is_inputs_all_injective_source = true; - VisitEachInput(op, [&](const pir::Operation* input){ - is_inputs_all_injective_source = (is_inputs_all_injective_source && op_2_is_injective_source.at(input)); - }); + VisitInputOp(op, + [&](const pir::Operation* input){ + if (IsInThisFusionOp(input)){ + is_inputs_all_injective_source = (is_inputs_all_injective_source && op_2_is_injective_source.at(input)); + } + } + ); return is_inputs_all_injective_source; }; @@ -138,7 +156,7 @@ class StmtFusionHelper { std::list ConvertToStmtsPattern() const { std::list ret; - for (const auto* op : fusion_op_.block()->ops()) { + for (const auto* op : fusion_op_.GetOperators()) { if (!IsInThisFusionOp(op)) continue; ret.emplace_back(ConvertToStmtPattern(op)); } @@ -190,7 +208,6 @@ class StmtFusionHelper { std::optional Fuse_IS_x_PS_2_PS(std::list* stmt_patterns) const { return FuseFilteredStmtPatterns(stmt_patterns); } - struct FusePolicy_IS_x_R_2_R { static bool FuseCondition(const StmtPattern& upstream, const StmtPattern& downstream) { return IsISPattern(upstream) && IsRPattern(downstream); @@ -246,10 +263,41 @@ class StmtFusionHelper { } private: - using StmtIter = std::list::iterator; + + StmtPattern ConvertToStmtPattern(const pir::Operation* op) const { + const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); + if (IsInjectiveSource(op)) { + return ConvertToIS(op); + } else if (kind == hlir::framework::kReduction) { + return ConvertReductionOpToReductionPattern(op); + } else if (kind == hlir::framework::kElementWise) { + return ConvertOpToPS(op); + } else if (kind == hlir::framework::kBroadcast) { + return ConvertOpToPS(op); + } else { + LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->op_name(); + } + LOG(FATAL) << "Dead code"; + } + + IS ConvertToIS(const pir::Operation* op) const { + return IS{{op}}; + } + + R ConvertReductionOpToReductionPattern(const pir::Operation* op) const { + return R{{}, {op}}; + } + + PS ConvertOpToPS(const pir::Operation* op) const { + const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); + return PS{ + .ops={op}, + .shardable_axes_signature=MakeShardableAxesSignature4Op(op), + }; + } static std::function(const pir::Operation*)> - MakeGetterStmt4Op(std::list* stmts) const { + MakeStmtFinderFromOp(std::list* stmts) { std::unordered_map op2stmt_iter; for (auto iter = stmts->begin(); iter != stmts->end(); ++iter) { VisitStmtOp(*iter, [&](const auto* op) { op2stmt_iter[op] = iter; }); @@ -261,28 +309,17 @@ class StmtFusionHelper { }; } - template - void VisitStmtOpImpl(const IS& injective_source, const DoEachT& DoEach) const { - for (const auto* op : injective_source.ops) { - DoEach(op); + std::function MakeTopoOrderFinderOfOp(cinn::dialect::FusionOp& fusion_op) const { + std::unordered_map op2order_in_block; + size_t order = 0; + for (const pir::Operation* op : fusion_op.GetOperators()) { + op2order_in_block[op] = ++order; } - } - - template - void VisitStmtOpImpl(const R& reduce, const DoEachT& DoEach) const { - DoEach(reduce.reduce_op); - } - - template - void VisitStmtOpImpl(const PS& partial_shardable, const DoEachT& DoEach) const { - for (const auto* op : partial_shardable.ops) { - DoEach(op); - } - } - - template - void VisitStmtOp(const StmtPattern& stmt, const DoEachT& DoEach) const { - std::visit([&](const auto& impl) { VisitStmtOpImpl(impl, DoEach); }, stmt); + return [map=std::move(op2order_in_block)](const pir::Operation* op) { + const auto& iter = map.find(op); + CHECK(iter != map.end()); + return iter->second; + }; } template @@ -290,13 +327,13 @@ class StmtFusionHelper { const IsDetailPatternT& IsDetailPattern, const ConstructPatternT& ConstructPattern, std::list* stmts) const { - const auto StmtIter4Op = MakeGetterStmt4Op(stmts); - using NodeVisitor = std::function; + const auto StmtFinder = MakeStmtFinderFromOp(stmts); + const auto VisitInputStmt = [&](StmtIter stmt, const NodeVisitor& DoEach) { VisitStmtOp(*stmt, [&](const auto* op){ VisitInputOp(op, [&](const pir::Operation* input) { - if (const auto& input_stmt = StmtIter4Op(input)) { - if (IsDetailPattern(*input_stmt.value())) { + if (const auto& input_stmt = StmtFinder(input)) { + if (IsDetailPattern(input_stmt->value())) { DoEach(input_stmt.value()); } } @@ -306,7 +343,7 @@ class StmtFusionHelper { const auto VisitOutputStmt = [&](StmtIter stmt, const NodeVisitor& DoEach) { VisitStmtOp(*stmt, [&](const auto* op){ VisitOutputOp(op, [&](const pir::Operation* output) { - if (const auto& output_stmt = StmtIter4Op(output)) { + if (const auto& output_stmt = StmtFinder(output)) { if (IsDetailPattern(*output_stmt.value())) { DoEach(output_stmt.value()); } @@ -322,12 +359,12 @@ class StmtFusionHelper { }); return num_injective_src_outputs == 0; }; - const auto GetOrder = MakeGetterOrderValue4Op(fusion_op_); + const auto GetOrder = MakeTopoOrderFinderOfOp(fusion_op_); const auto Cmp = [&](const auto* lhs, const auto& rhs) { return GetOrder(lhs) < GetOrder(rhs); }; common::BfsWalker reverse_walker(VisitInputStmt); - const auto& GetVisitedOps = [&](const auto stmt_iter) { + const auto& GetUpstreamOps = [&](const auto stmt_iter) { std::vector visited_ops; reverse_walker(start, [&](const auto node){ VisitStmtOp(node, [&](const auto* op) { visited_ops.push_back(op); }); @@ -338,7 +375,7 @@ class StmtFusionHelper { std::list fused_stmts; for (auto stmt_iter = stmts->begin(); stmt_iter != stmts->end(); ++stmt_iter) { if (!IsSinkPattern(stmt_iter)) continue; - fused_stmts.emplace_back(ConstructPattern(GetVisitedOps(stmt_iter))); + fused_stmts.emplace_back(ConstructPattern(GetUpstreamOps(stmt_iter))); } for (auto stmt_iter = stmts->begin(); stmt_iter != start->end();) { if (IsDetailPattern(*stmt_iter)) { @@ -350,66 +387,11 @@ class StmtFusionHelper { stmts->splice(stmts->begin(), std::move(fused_stmts)); return std::nullopt; } - - using OpVisitor = std::function; - - void VisitInputOp(const pir::Operation* op, const OpVisitor& DoEach) const { - for (int i = 0; i < op->num_operands(); ++i) { - const auto* input_op = op->operand_source(i).defining_op(); - if (IsInThisFusionOp(input_op)) { - DoEach(input_op); - } - } - } - - void VisitOutputOp(const pir::Operation* op, const OpVisitor& DoEach) const { - for (int i = 0; i < op->num_results(); ++i) { - pir::Value output = op->result(i); - for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) { - const auto* consumer_op = consumer_it->owner(); - if (IsInThisFusionOp(consumer_op)) { - DoEach(consumer_op); - } - } - } - } - - StmtPattern ConvertToStmtPattern(const pir::Operation* op) const { - const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); - if (IsInjectiveSource(op)) { - return ConvertToIS(op); - } else if (kind == hlir::framework::kReduction) { - return ConvertReductionOpToReductionPattern(op); - } else if (kind == hlir::framework::kElementWise) { - return ConvertOpToPS(op); - } else if (kind == hlir::framework::kBroadcast) { - return ConvertOpToPS(op); - } else { - LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->op_name(); - } - LOG(FATAL) << "Dead code"; - } - - IS ConvertToIS(const pir::Operation* op) const { - return IS{{op}}; - } - - R ConvertReductionOpToReductionPattern(const pir::Operation* op) const { - return R{{}, {op}}; - } size_t GetRank(pir::Value value) const { return value.type().dyn_cast().dims().size(); }; - PS ConvertOpToPS(const pir::Operation* op) const { - const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); - return PS{ - .ops={op}, - .shardable_axes_signature=MakeShardableAxesSignature4Op(op), - }; - } - ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) const { const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); if (kind == hlir::framework::kElementWise) { @@ -462,6 +444,28 @@ class StmtFusionHelper { StmtIter downstream_iter; }; + bool IsConnected(const StmtIter& upstream, const StmtIter& downstream){ + const auto StmtFinder = MakeStmtFinderFromOp({*upstream, *downstream}); + const auto VisitInputStmt = [&](StmtIter stmt, const NodeVisitor& DoEach) { + VisitStmtOp(*stmt, [&](const auto* op)){ + VisitInputOp(op, [&](const pir::Operation* input) { + if (const auto& input_stmt = StmtFinder(input)) { + if (IsDetailPattern(input_stmt->value())) { + DoEach(input_stmt.value()); + } + } + }); + }; + }; + + auto downstream_input_patterns = std::unordered_set(); + VisitInputStmt(*downstream, [&](const StmtIter& input_pattern){ + downstream_input_patterns.insert(input_pattern); + }) + + return downstream_input_patterns.count(upstream) > 0; + } + template std::optional FindConnetedPattenPairWithCondition( std::list* stmt_patterns, diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 9273a722e25c51..394dea68c112e6 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -77,6 +77,7 @@ class IR_API FusionOp : public pir::Op { pir::Block *block(); std::vector GetOperators(); + std::vector GetOperators() const; void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT