[hybrid] Support tensor parallel and cache structure for fused attention op. #40101
Conversation
Thanks for your contribution!
cache structure support for fuse attention
LGTM
LGTM for
set_tests_properties(test_static_model_parallel_fused_attention PROPERTIES TIMEOUT 120)
LGTM
LGTM.
LGTM
PR types
Others
PR changes
Others
Describe
fused_attention_op is modified as follows:
- Added an optional CacheKV input, which carries the previous step's cache values inside the generation model's while loop.
- Added a CacheKVOut output, which carries the cache values updated in the current step inside the generation model's while loop.
- Modified the attributes: added an optional ring_id attribute (default -1), used as the communication-group identifier for tensor-parallel distributed training.
This updates the fused_attention op to support tensor model parallelism and the cache structure.
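The caching pattern described above can be sketched in plain numpy. This is an illustrative model only, not the op's actual interface: the `step` helper, the `[2, seq_len, num_heads, head_dim]` cache layout (K stacked over V), and all shapes are assumptions for the sketch. At each decoding step the current token's key/value is appended to the cache from the previous step (playing the role of CacheKV in, CacheKVOut out).

```python
import numpy as np

num_heads, head_dim = 2, 4

def step(cache_kv, new_k, new_v):
    """Append this step's K/V to the running cache (hypothetical helper).

    cache_kv: [2, seq_len, num_heads, head_dim] with K stacked over V,
    or None on the first decoding step.
    """
    new_kv = np.stack([new_k, new_v])[:, None]  # -> [2, 1, num_heads, head_dim]
    if cache_kv is None:
        return new_kv
    # Grow the cache along the sequence axis; this grown cache is what
    # the next iteration of the while loop receives as its input cache.
    return np.concatenate([cache_kv, new_kv], axis=1)

cache = None
for t in range(3):  # three decoding steps
    k = np.full((num_heads, head_dim), float(t))  # this step's key
    v = np.full((num_heads, head_dim), float(t))  # this step's value
    cache = step(cache, k, v)

# After 3 steps the cache holds 3 positions of K and V.
assert cache.shape == (2, 3, num_heads, head_dim)
```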
For tensor model parallelism, the first linear is column-parallel and the second is row-parallel; after the row-parallel linear each rank holds only a partial output, and an allreduce over the ranks sums these partials into the final output.
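The column-parallel/row-parallel split can be checked numerically with a small numpy sketch. This simulates the ranks in a single process (the allreduce becomes a plain sum over rank-local partials); the shapes and the two-rank setup are assumptions for illustration, not Paddle's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2
d_model, d_hidden = 4, 8

x = rng.standard_normal((3, d_model))          # replicated input activations
w1 = rng.standard_normal((d_model, d_hidden))  # first linear (column-parallel)
w2 = rng.standard_normal((d_hidden, d_model))  # second linear (row-parallel)

# Serial reference: two back-to-back linears on one device.
ref = x @ w1 @ w2

# Shard w1 by columns and w2 by rows across the ranks.
w1_shards = np.split(w1, world_size, axis=1)
w2_shards = np.split(w2, world_size, axis=0)

# Each rank computes its partial output locally; no communication is
# needed between the column-parallel and row-parallel linears, because
# the column shard's output lines up with the matching row shard.
partials = [x @ w1_shards[r] @ w2_shards[r] for r in range(world_size)]

# allreduce(sum) over ranks yields the final output.
out = np.sum(partials, axis=0)

assert np.allclose(out, ref)
```

This is why only a single allreduce per pair of linears is required: the per-rank products `x @ w1_r @ w2_r` sum exactly to `x @ w1 @ w2` by block matrix multiplication.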