Add flash attention to speedup fused_gate_attention. #52731

Merged: 58 commits, May 19, 2023

Commits
49ec05c
first commit
JamesLim-sy Apr 10, 2023
76b87c2
fix some bugs
JamesLim-sy Apr 18, 2023
a26df02
fix some bugs
JamesLim-sy Apr 18, 2023
cc08899
fix bugs in flashattn.h
JamesLim-sy Apr 19, 2023
5bee3a6
:qix pointer bugs for my yesterday errors
JamesLim-sy Apr 19, 2023
afce9d6
add for backward
JamesLim-sy Apr 23, 2023
9f76b5f
fix bugs for backward
JamesLim-sy Apr 23, 2023
05b3444
04-24 first commit
JamesLim-sy Apr 24, 2023
3e41967
Merge branch 'develop' into add_flash_attn_for_af2
Xreki Apr 24, 2023
69a80cf
fix code conflicts
JamesLim-sy Apr 24, 2023
9d0befe
Reorganize the forward codes of flash-attention.
Xreki Apr 24, 2023
66f07bc
Fix forward.
Xreki Apr 24, 2023
d66d507
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
Xreki Apr 24, 2023
387d26f
Remove some noused codes.
Xreki Apr 24, 2023
cf4a1c8
Change backward.
Xreki Apr 24, 2023
2092b92
Merge branch 'develop' into add_flash_attn_for_af2
Xreki Apr 25, 2023
ade7a07
Simplify codes.
Xreki Apr 25, 2023
1ddf939
Simplify codes and fix backward.
Xreki Apr 25, 2023
b1668b0
Fix calling for tensor.data.
Xreki Apr 25, 2023
7f9d905
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
JamesLim-sy Apr 25, 2023
3b97303
Change all LOG(INFO) to VLOG and fix the backward.
Xreki Apr 25, 2023
5a9f08d
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
JamesLim-sy Apr 25, 2023
6b9b49c
add scale for AF2 flash_attn, much thanks to xreki and shaojie for de…
JamesLim-sy Apr 27, 2023
2e31f24
much thanks to xreki and shaojie for debuging this flash_attn in Alph…
JamesLim-sy Apr 27, 2023
3262172
much thanks to xreki and shaojie for debuging this flash_attn in Alph…
JamesLim-sy Apr 27, 2023
8c02766
decrease the effect of debug print on performance
JamesLim-sy Apr 27, 2023
6d389d4
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 5, 2023
20cdc33
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
Xreki May 5, 2023
bd321df
Unify the initialize of flashattn arguments.
Xreki May 5, 2023
08a8b75
Rewirte the reshape of temp_mask and temp_bias.
Xreki May 5, 2023
6a65ee0
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 6, 2023
4682c0d
Update commit and fix reduce_dim.
Xreki May 6, 2023
165afab
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 6, 2023
dd2860e
API support use_flash_attn.
Xreki May 6, 2023
fe80730
Fix compiling error on CI.
Xreki May 8, 2023
184cf9a
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 8, 2023
65c6ed1
fix tag commit id with newest
JamesLim-sy May 8, 2023
462e36e
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy May 8, 2023
3458f8c
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy May 8, 2023
be44a91
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy May 8, 2023
a9ba1ba
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
Xreki May 8, 2023
c0f497a
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
JamesLim-sy May 8, 2023
bfe5a8c
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy May 8, 2023
0b7fda0
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
JamesLim-sy May 8, 2023
ac3ff47
change flash_attn commit id
JamesLim-sy May 8, 2023
ce937f6
fix op unitest for fused_gate_attn
JamesLim-sy May 8, 2023
23c108f
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 10, 2023
ad3f70a
Merge branch 'add_flash_attn_for_af2' of /~https://github.com/JamesLim-…
Xreki May 11, 2023
7ff9f5e
Try to crop the flash-attention lib.
Xreki May 11, 2023
e92a9bb
Correct the condition of whether can use flash-attn.
Xreki May 11, 2023
df8c302
Remove the softmax_out argument.
Xreki May 12, 2023
fc3c281
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 12, 2023
d01c89c
Remove is_causal.
Xreki May 12, 2023
91a0ea5
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 17, 2023
f6be954
Polish codes.
Xreki May 17, 2023
ba84941
Fix qkv_transpose_out's shape and scaling of Q * K.
Xreki May 18, 2023
bee8537
Merge branch 'develop' into add_flash_attn_for_af2
Xreki May 18, 2023
3747978
Update commit of flash-attention.
Xreki May 18, 2023
Files changed
2 changes: 1 addition & 1 deletion cmake/external/flashattn.cmake
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
+set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
Changes to another file (filename not shown)
@@ -27,6 +27,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
fused_gate_attention_dygraph_function(
const paddle::Tensor& Query,
Changes to another file (filename not shown)
@@ -26,6 +26,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
fused_gate_attention_dygraph_function(
const paddle::Tensor& Query,
@@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function(
{"SoftmaxOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"SoftmaxLse",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"FMHAOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
@@ -256,6 +260,8 @@ fused_gate_attention_dygraph_function(
egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut);
paddle::Tensor SoftmaxOut;
egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
paddle::Tensor SoftmaxLse;
egr::EagerUtils::GetOutput(outs["SoftmaxLse"][0], &SoftmaxLse);
paddle::Tensor FMHAOut;
egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
paddle::Tensor GateOut;
@@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function(
p_autograd_Out);
// Create GradOpNode
auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
-new fused_gate_attentionGradNodeCompat(8, 12));
+new fused_gate_attentionGradNodeCompat(9, 12));

bool merge_qkv = true;
if (attrs.count("merge_qkv")) {
@@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function(
has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating"));
}

bool use_flash_attn = false;
if (attrs.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attrs.at("use_flash_attn"));
}

// Set Attributes
grad_node->SetAttrMap(std::move(attrs));
grad_node->SetDefaultAttrMap(std::move(default_attrs));
@@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function(
grad_node->SetGradOutMeta(NonbatchedBias, 6);
}

if (use_flash_attn) {
grad_node->SetTensorWrapperSoftmaxLse(SoftmaxLse);
grad_node->SetTensorWrapperSrcMask(SrcMask);
grad_node->SetGradOutMeta(SrcMask, 7);
}

egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0);
grad_node->SetGradInMeta(QueryTransposeOut, 0);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1);
@@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function(
ValueTransposeOut,
QKVTransposeOut,
SoftmaxOut,
SoftmaxLse,
FMHAOut,
GateOut,
Out);
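For context on the new SoftmaxLse output added above (and the grad node slot count going from 8 to 9): flash attention never materializes the full softmax matrix, so the forward kernel keeps only a per-row log-sum-exp, and the backward pass recomputes the attention probabilities from Q, K, the bias/mask and that saved value. A rough sketch of the standard flash-attention bookkeeping, stated generically rather than copied from this PR's kernels:

L_i = \log \sum_j \exp(\tau \, q_i^\top k_j + b_{ij}), \qquad P_{ij} = \exp(\tau \, q_i^\top k_j + b_{ij} - L_i)

where \tau is the softmax scale (typically 1/\sqrt{d_{head}}) and b_{ij} collects the nonbatched bias and mask terms. Because only L_i has to be stored, the flash path saves SoftmaxLse for backward instead of a materialized SoftmaxOut.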
Changes to another file (filename not shown)
@@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()(
has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating"));
}

bool use_flash_attn = false;
if (attr_map_.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attr_map_.at("use_flash_attn"));
}

std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
{{"FMHAOut",
egr::EagerUtils::TrySyncToVars(
@@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()(
egr::Controller::Instance().GenerateUniqueName())};
}

if (use_flash_attn) {
auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_);
ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
auto SoftmaxLse = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxLse_);
ins0["SoftmaxLse"] = egr::EagerUtils::TrySyncToVars(SoftmaxLse);
}

auto& attrs_map0 = this->attr_map_;
// Pass the entire attribute map to TraceOp
// The underlying kernel will pickup whatever attribute they need at runtime
10 changes: 10 additions & 0 deletions paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
@@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
GateOut_.clear();
GateWeight_.clear();
NonbatchedBias_.clear();
SrcMask_.clear();
OutLinearBias_.clear();
OutLinearWeight_.clear();
QKVTransposeOut_.clear();
QKVWeight_.clear();
Query_.clear();
SoftmaxOut_.clear();
SoftmaxLse_.clear();
Key_.clear();
QueryWeight_.clear();
KeyWeight_.clear();
@@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) {
NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false);
}
void SetTensorWrapperSrcMask(const paddle::Tensor& SrcMask) {
SrcMask_ = egr::TensorWrapper(SrcMask, false);
}
void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) {
OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
}
@@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
}
void SetTensorWrapperSoftmaxLse(const paddle::Tensor& SoftmaxLse) {
SoftmaxLse_ = egr::TensorWrapper(SoftmaxLse, false);
}
void SetTensorWrapperKey(const paddle::Tensor& Key) {
Key_ = egr::TensorWrapper(Key, false);
}
@@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
egr::TensorWrapper GateOut_;
egr::TensorWrapper GateWeight_;
egr::TensorWrapper NonbatchedBias_;
egr::TensorWrapper SrcMask_;
egr::TensorWrapper OutLinearBias_;
egr::TensorWrapper OutLinearWeight_;
egr::TensorWrapper QKVTransposeOut_;
egr::TensorWrapper QKVWeight_;
egr::TensorWrapper Query_;
egr::TensorWrapper SoftmaxOut_;
egr::TensorWrapper SoftmaxLse_;

egr::TensorWrapper Key_;
egr::TensorWrapper QueryWeight_;
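The SrcMask_ and SoftmaxLse_ wrappers added above follow the eager grad node's save-for-backward pattern: at forward time the node wraps exactly the tensors its backward path will need, and at backward time it recovers them (via egr::EagerUtils::RecoverTensorWrapper in the real code). Below is a minimal standalone sketch of that pattern with simplified stand-in types; the type and method names are hypothetical, not the actual paddle::Tensor / egr::TensorWrapper API, and it assumes the flash path only consumes the log-sum-exp and the mask while the legacy path keeps the full softmax:

// Simplified, hypothetical sketch of the save-for-backward pattern;
// the real grad node uses egr::TensorWrapper members and
// egr::EagerUtils::RecoverTensorWrapper instead of std::optional.
#include <cassert>
#include <optional>
#include <string>
#include <vector>

struct Tensor {  // stand-in for paddle::Tensor
  std::string name;
  std::vector<float> data;
};

class FusedGateAttentionGradNode {
 public:
  // Called at forward time: capture only what the chosen backward path needs.
  void SaveForBackward(bool use_flash_attn,
                       const Tensor& softmax_out,
                       const Tensor& softmax_lse,
                       const Tensor& src_mask) {
    use_flash_attn_ = use_flash_attn;
    if (use_flash_attn_) {
      softmax_lse_ = softmax_lse;  // per-row log-sum-exp for recomputation
      src_mask_ = src_mask;        // mask is reapplied when P is rebuilt
    } else {
      softmax_out_ = softmax_out;  // legacy path stores the full softmax
    }
  }

  // Called at backward time: recover the wrapped tensors and dispatch.
  void Backward() const {
    if (use_flash_attn_) {
      assert(softmax_lse_.has_value() && src_mask_.has_value());
      // ... run the flash-attention backward kernel with *softmax_lse_,
      //     *src_mask_, Q/K/V and the upstream gradient.
    } else {
      assert(softmax_out_.has_value());
      // ... reuse the stored softmax matrix directly.
    }
  }

 private:
  bool use_flash_attn_ = false;
  std::optional<Tensor> softmax_out_;
  std::optional<Tensor> softmax_lse_;
  std::optional<Tensor> src_mask_;
};

On the flash path a caller would invoke SaveForBackward(true, {}, lse, mask) after the forward kernel and later Backward(); with use_flash_attn = false only the softmax tensor is retained, mirroring how the real grad node conditionally wraps SrcMask and SoftmaxLse.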