[hybrid] Fix model parallel non-distributed param broadcast #36186
Conversation
Thanks for your contribution!

    rings.append(self.dp_ring_id)

    # need sync non distributed param in mp group
    if self.mp_ring_id is not None:
Wouldn't it be better to put this somewhere else? Why is the mp initialization sync implemented inside the offload logic?
Because the parameters need to be broadcast first and only then can the cast and memcpy ops be inserted; otherwise the fp16 parameters and the offload variables would become inconsistent across the cards.
Alternatively, you could add a dedicated piece of logic directly in _initialization_broadcast to handle the case where offload and optimize_cast need the parameters broadcast first. That might be a bit more cumbersome, but from a modularity standpoint it would indeed be better.
I'll add dedicated logic to handle this requirement later.
LGTM
PR types
Bug fixes
PR changes
Others
Describe
1. In hybrid parallelism, the non-distributed parameters under mp must stay consistent across all mp ranks. This PR fixes the issue that mp's non-distributed parameters were not broadcast (see the first sketch below).
2. When optimize_offload or (optimize_cast + optimize_sharding) is enabled, params are set to non-persistable, which causes an error during program.clone. Since this problem currently only occurs in hybrid parallelism, this PR fixes it by regenerating the non-persistable params as plain vars (see the second sketch below).
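A minimal sketch of the mp broadcast described in item 1, assuming the usual `c_broadcast` / `c_sync_comm_stream` collective ops and the `is_distributed` attribute on parameters; the helper name `_broadcast_params_in_mp` and the exact attribute handling are illustrative, not the exact code in this PR:

```python
# Illustrative sketch only: broadcast every non-distributed parameter
# inside the model-parallel (mp) communication ring so that all mp ranks
# start from identical values. Names such as `mp_ring_id` and this helper
# are hypothetical and not taken verbatim from the PR.
def _broadcast_params_in_mp(block, mp_ring_id):
    params = block.all_parameters()
    for param in params:
        # Parameters that are sharded/split across mp ranks (is_distributed)
        # must NOT be broadcast; only replicated parameters need syncing.
        if getattr(param, 'is_distributed', False):
            continue
        block.append_op(
            type='c_broadcast',
            inputs={'X': param},
            outputs={'Out': param},
            attrs={'ring_id': mp_ring_id, 'root': 0})
    # Make sure the broadcasts finish before any later cast/memcpy ops
    # (e.g. those inserted by optimize_cast / offload) read the parameters.
    block.append_op(
        type='c_sync_comm_stream',
        inputs={'X': params},
        outputs={'Out': params},
        attrs={'ring_id': mp_ring_id})
```

This ordering matches the point raised in the review: the broadcast must be emitted before the cast/memcpy ops so that fp16 copies and offload variables are derived from parameters that are already identical on every card.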
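A similarly hedged sketch of the param-to-var regeneration described in item 2, assuming the Block APIs `_remove_var` and `create_var`; the helper name `_recreate_non_persistable_param_as_var` and the exact fields copied are assumptions, not the PR's actual implementation:

```python
# Illustrative sketch only: a parameter that offload/optimize_cast has
# flipped to non-persistable is removed from the block and recreated as
# an ordinary variable with the same name, shape, and dtype, so that
# program.clone() no longer has to handle a non-persistable Parameter.
def _recreate_non_persistable_param_as_var(block, param_name):
    param = block.var(param_name)
    if param.persistable:
        # Still a normal persistable parameter; nothing to do.
        return param
    shape, dtype = param.shape, param.dtype
    block._remove_var(param_name)
    return block.create_var(
        name=param_name,
        shape=shape,
        dtype=dtype,
        persistable=False)
```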