Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support GQA #7906

Merged
merged 7 commits into from
Feb 23, 2024
Merged

support GQA #7906

merged 7 commits into from
Feb 23, 2024

Conversation

zhangting2020
Copy link
Contributor

@zhangting2020 zhangting2020 commented Jan 26, 2024

PR types

Performance optimization

PR changes

Others

Description

support GQA

目前Paddle的develop已经为fa和fuse_rope支持gqa/mqa,但2.6及以下版本未支持,因此为模型实现兼容性修改。

设置num_key_value_heads = 8,使用Paddle2.6版本和Paddle-dev分别测试llama2-13b模型收敛情况。
image

Copy link

codecov bot commented Jan 26, 2024

Codecov Report

Attention: 12 lines in your changes are missing coverage. Please review.

Comparison is base (7e643ad) 56.56% compared to head (1596b37) 56.56%.
Report is 5 commits behind head on develop.

Files Patch % Lines
paddlenlp/transformers/llama/modeling.py 29.41% 12 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #7906   +/-   ##
========================================
  Coverage    56.56%   56.56%           
========================================
  Files          589      589           
  Lines        89939    89964   +25     
========================================
+ Hits         50874    50891   +17     
- Misses       39065    39073    +8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +936 to +953
query_states, _, _ = fused_rotary_position_embedding(
query_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
key_states, _, _ = fused_rotary_position_embedding(
key_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GQA的时候,原来的代码 用 fused_rotary_position_embedding 是有问题的吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,Paddle旧版本的fused_rotary_position_embedding不支持传入的q和k/v 有不同的heads,所以等效的方式是单独处理q,k,需要分别调用2次接口。

我们在dev已经做了支持,所以可以直接调用1次接口。

ZHUI
ZHUI previously approved these changes Feb 21, 2024
Copy link
Collaborator

@ZHUI ZHUI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wawltor wawltor merged commit 6a6a9fe into PaddlePaddle:develop Feb 23, 2024
5 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants