support GQA #7906
Conversation
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #7906     +/-   ##
==========================================
  Coverage    56.56%   56.56%
==========================================
  Files          589      589
  Lines        89939    89964      +25
==========================================
+ Hits         50874    50891      +17
- Misses       39065    39073       +8

View full report in Codecov by Sentry.
```python
query_states, _, _ = fused_rotary_position_embedding(
    query_states,
    None,
    None,
    sin=sin,
    cos=cos,
    position_ids=position_ids,
    use_neox_rotary_style=False,
)
key_states, _, _ = fused_rotary_position_embedding(
    key_states,
    None,
    None,
    sin=sin,
    cos=cos,
    position_ids=position_ids,
    use_neox_rotary_style=False,
)
```
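Splitting the fused op into one call for `query_states` and one for `key_states` works because rotary embedding rotates each head's feature pairs independently, so the head count never enters the math. A minimal NumPy sketch of this property, assuming the rotate-half convention and made-up GQA shapes (16 query heads, 8 KV heads); it is an illustration, not the fused kernel's actual layout:

```python
import numpy as np

def apply_rope(x, sin, cos):
    """Apply rotary position embedding to x of shape
    [batch, seq, num_heads, head_dim]. The rotation touches only the
    last axis, so the same function serves q and k even when they have
    different head counts (the GQA case)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)  # rotate-half convention
    return x * cos + rotated * sin

batch, seq, head_dim = 1, 4, 8
num_heads, num_kv_heads = 16, 8  # GQA: fewer KV heads than query heads

# Standard RoPE angle table, duplicated across the two halves.
half = head_dim // 2
inv_freq = 1.0 / (10000.0 ** (np.arange(half) / half))
angles = np.arange(seq)[:, None] * inv_freq[None, :]   # [seq, half]
angles = np.concatenate([angles, angles], axis=-1)     # [seq, head_dim]
sin = np.sin(angles)[None, :, None, :]  # broadcasts over batch and heads
cos = np.cos(angles)[None, :, None, :]

rng = np.random.default_rng(0)
q = rng.standard_normal((batch, seq, num_heads, head_dim))
k = rng.standard_normal((batch, seq, num_kv_heads, head_dim))

# Two independent calls, mirroring the diff above.
q_rot = apply_rope(q, sin, cos)
k_rot = apply_rope(k, sin, cos)
```

Because each position applies a pure rotation per feature pair, the transform also preserves vector norms, which is a quick sanity check for any RoPE implementation.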
For GQA, was the original code's use of fused_rotary_position_embedding actually broken?
Yes. Older Paddle versions of fused_rotary_position_embedding do not support passing q and k/v with different numbers of heads, so the equivalent workaround is to handle q and k separately, calling the interface twice.
We have already added support on dev, so there a single call to the interface is enough.
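The compatibility rule stated above (two calls on Paddle 2.6 and below, one call on dev) can be sketched as a small version gate. The helper name and the version parsing are illustrative assumptions, not PaddleNLP's actual dispatch code:

```python
def needs_split_rope_calls(paddle_version: str) -> bool:
    """Return True when fused_rotary_position_embedding must be invoked
    twice (once for q, once for k), i.e. on Paddle 2.6 and below, which
    reject q and k/v with differing head counts. Illustrative check only;
    dev builds with non-semantic version strings are not handled."""
    major, minor = (int(p) for p in paddle_version.split(".")[:2])
    return (major, minor) <= (2, 6)
```

Usage: a model could branch on `needs_split_rope_calls(paddle.__version__)` to pick between the single fused call and the two-call fallback shown in the diff.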
LGTM
PR types
Performance optimization

PR changes
Others

Description
Support GQA. Paddle's develop branch already supports GQA/MQA for flash attention and fuse_rope, but versions 2.6 and below do not, so this PR adds compatibility changes to the model.
With num_key_value_heads = 8, convergence of the llama2-13b model was tested on both Paddle 2.6 and Paddle-dev.