Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#42 from feifei-111/cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
Cinn trivalop fuse
  • Loading branch information
feifei-111 authored Mar 9, 2024
2 parents badeae6 + 83d1e79 commit b0d6347
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 184 deletions.
48 changes: 24 additions & 24 deletions paddle/cinn/api/op_topo_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,29 @@ struct ReductionPattern {
SingleReductionOpPattern<T> reduction_op_pattern;
};

// Stmt := IS | R | PS
// ops in StmtPattern will be lowered into a inlined cuda code.
template <typename T>
using StmtPattern = std::variant<InjectiveSourcePattern<T>, ReductionPattern<T>, PartialShardablePattern<T>>;

// Stmts := [Stmt]
template <typename T>
using StmtsPattern = std::list<StmtPattern>;

// fuse rules:
// 1. IS * IS -> IS
// 2. PS * PS -> PS
// 3. IS * PS -> PS
// 4. IS * R -> R
// 5. PS * R -> R

// lifting rules:
// 1. R -> Stmts
// 2. PS -> Stmts
// 3. Stmts * Stmts -> Stmts

// OpTopoPattern := Error | Stmts
template <typename T>
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtsPattern<T>>;
// // Stmt := IS | R | PS
// // ops in StmtPattern will be lowered into a inlined cuda code.
// template <typename T>
// using StmtPattern = std::variant<InjectiveSourcePattern<T>, ReductionPattern<T>, PartialShardablePattern<T>>;

// // Stmts := [Stmt]
// template <typename T>
// using StmtsPattern = std::list<StmtPattern<T>>;

// // fuse rules:
// // 1. IS * IS -> IS
// // 2. PS * PS -> PS
// // 3. IS * PS -> PS
// // 4. IS * R -> R
// // 5. PS * R -> R

// // lifting rules:
// // 1. R -> Stmts
// // 2. PS -> Stmts
// // 3. Stmts * Stmts -> Stmts

// // OpTopoPattern := Error | Stmts
// template <typename T>
// using OpTopoPattern = std::variant<ErrorPattern<T>, StmtsPattern<T>>;

}
3 changes: 2 additions & 1 deletion paddle/cinn/frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ gather_srcs(
op_mapper_registry.cc
paddle_model_convertor.cc
program_pass.cc
optimize.cc)
optimize.cc
group_pattern_util.cc)

if(NOT WITH_CUDA)
cinn_cc_test(
Expand Down
40 changes: 21 additions & 19 deletions paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,38 @@
#include <unordered_map>
#include <atomic>
#include <vector>
#include <unordered_map>
#include <variant>
#include "paddle/cinn/api/op_topo_pattern.h"
#include "paddle/pir/include/core/operation.h"
#include "glog/logging.h"

namespace cinn::frontend {
namespace cinn::api {

struct FrontendPattern {};

}

namespace cinn::api {

template<>
struct ErrorPattern<frontend::FrontendPattern> {
explicit ErrorPattern(const ErrorPattern<frontend::FrontendPatterns>& other) = default;
struct ErrorPattern<FrontendPattern> {
explicit ErrorPattern(const ErrorPattern<FrontendPattern>& other) = default;

std::vector<const pir::Operation*> ops;
std::string error_string;
};

template<>
struct InjectiveSourcePattern<frontend::FrontendPattern> {
explicit InjectiveSourcePattern(const InjectiveSourcePattern<frontend::FrontendPatterns>& other) = default;
struct InjectiveSourcePattern<FrontendPattern> {
explicit InjectiveSourcePattern(const InjectiveSourcePattern<FrontendPattern>& other) = default;
std::vector<const pir::Operation*> ops;
};

template<>
struct SingleReductionOpPattern<frontend::FrontendPattern> {
explicit SingleReductionOpPattern(const SingleReductionOpPattern<frontend::FrontendPatterns>& other) = default;
struct SingleReductionOpPattern<FrontendPattern> {
explicit SingleReductionOpPattern(const SingleReductionOpPattern<FrontendPattern>& other) = default;
const pir::Operation* reduce_op;
};

struct ShardableAxis {
int axis;
std::optional<std::string> axis_name;
std::string axis_name;

bool operator==(const ShardableAxis& other) const {
return this->axis == other.axis && this->axis_name == other.axis_name;
Expand All @@ -51,7 +49,7 @@ struct ShardableAxis {
using ShardableAxes = std::vector<ShardableAxis>;

struct ShardableAxesUtil {
using OldName2NewName = std::unorderd_map<std::string, std::string>;
using OldName2NewName = std::unordered_map<std::string, std::string>;

static OldName2NewName GetOldName2NewName(const ShardableAxes& old_sa, const ShardableAxes& new_sa) {
OldName2NewName old_name2new_name;
Expand All @@ -69,7 +67,7 @@ struct ShardableAxesUtil {
for (auto iter = sa->begin(); iter != sa->end();) {
const auto& pair_it = old2new.find(iter->axis_name);
if (pair_it != old2new.end()) {
iter->axis_name = pair_it.second;
iter->axis_name = pair_it->second;
++iter;
} else {
iter = sa->erase(iter);
Expand Down Expand Up @@ -109,8 +107,8 @@ struct ShardableAxesSignature {
};

template<>
struct PartialShardablePattern<frontend::FrontendPattern> {
explicit PartialShardablePattern(const PartialShardablePattern<frontend::FrontendPatterns>& other) = default;
struct PartialShardablePattern<FrontendPattern> {
explicit PartialShardablePattern(const PartialShardablePattern<FrontendPattern>& other) = default;

std::vector<const pir::Operation*> ops;
ShardableAxesSignature shardable_axes_signature;
Expand All @@ -119,8 +117,12 @@ struct PartialShardablePattern<frontend::FrontendPattern> {
}

namespace cinn::frontend {
using IS = api::InjectiveSourcePattern<api::FrontendPattern>;
using R = api::ReductionPattern<api::FrontendPattern>;
using PS = api::PartialShardablePattern<api::FrontendPattern>;

using GroupPattern = api::OpTopoPattern<FrontendPattern>;
using ErrorGroupPattern = api::ErrorPattern<FrontendPattern>;
using StmtPattern = std::variant<IS, R, PS>;
using ErrorGroupPattern = api::ErrorPattern<api::FrontendPattern>;
using GroupPattern = std::variant<ErrorGroupPattern, StmtPattern>;

}
Loading

0 comments on commit b0d6347

Please sign in to comment.