[Dy2stat]support pure fp16 for dy2stat #36944
Conversation
Thanks for your contribution!
@@ -118,6 +118,17 @@ def _in_amp_guard():
        return False


def _in_pure_fp16_guard():
    tracer = _dygraph_tracer()
    if tracer:
Suggested change:
-    if tracer:
+    return tracer and tracer._amp_level == core.AmpLevel.O2
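With the suggestion applied, the new helper would read roughly as follows (a minimal sketch; the import locations shown for `_dygraph_tracer` and `core` are assumptions about this module's existing imports):

from paddle.fluid import core                       # assumed import
from paddle.fluid.framework import _dygraph_tracer  # assumed import


def _in_pure_fp16_guard():
    # Returns True only when running inside a pure fp16 (AMP level O2)
    # auto_cast context, mirroring the existing _in_amp_guard helper;
    # if there is no tracer, the expression short-circuits to a falsy value.
    tracer = _dygraph_tracer()
    return tracer and tracer._amp_level == core.AmpLevel.O2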
""" | ||
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode. | ||
""" | ||
with fluid.dygraph.guard(place): |
Dygraph mode is already enabled by default, so this guard is no longer needed.
        dygraph_loss = self.train(to_static=False)
        self.assertTrue(
            np.allclose(static_loss, dygraph_loss),
            msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
Please make sure CI stays stable; if the tolerance is relaxed, a NOTE needs to be added explaining why.
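For illustration, the pattern the reviewer is asking for would look roughly like the sketch below; the 1e-3 value matches the atol discussed elsewhere in this PR, while the exact wording of the NOTE is an assumption:

        # NOTE: pure fp16 computation is less precise, so the comparison between
        # static and dygraph losses uses a relaxed atol (1e-3) to keep CI stable.
        self.assertTrue(
            np.allclose(static_loss, dygraph_loss, atol=1e-3),
            msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
                                                             dygraph_loss))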
LGTM
* run dy2stat pure fp16 in Linear model
* no use self._pure_fp16_inputs
* add test and fix Adam error in dy2stat pure fp16 training
* use paddle.optimizer.Adam
* run test in gpu
* change test time for CI
* enlarge atol for test_resnet_pure_fp16
* refine code and enlarge atol
* make custom_white_list and custom_black_list take effect for AMP and pure fp16
* check tracer is not None
* use default atol
* change filter_size
* change atol and add some NOTE
PR types
New features
PR changes
Others
Describe
* This PR adds special-case handling to the `CastPureFp16Inputs` function: the run_program op called during dynamic-to-static (dy2stat) execution is skipped and no further cast is applied to it. The run_program op only has an FP32 kernel, and in dy2stat pure fp16 training its inputs have already been cast to fp16 before the op is called, so `CastPureFp16Inputs` should skip the run_program op to avoid casting its inputs back to fp32.
* The tolerance used in the unit tests is relaxed to 1e-3 to make sure they pass. For dy2stat pure fp16 training, convergence has been verified on the mnist and resnet networks, and both converge normally.
* This PR also makes the black/white list settings of the `paddle.amp.auto_cast` interface (the custom_white_list and custom_black_list parameters) take effect for dy2stat. Before this PR, setting these two parameters in dy2stat AMP training had no effect.
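As a usage illustration only (not code from this PR's diff), a dy2stat pure fp16 training step with the black/white lists set might look roughly like the sketch below under the Paddle 2.x AMP API; the model, the op names in the lists, and the hyperparameters are placeholder assumptions:

import paddle

paddle.set_device('gpu')  # pure fp16 kernels require a GPU

# Placeholder model and optimizer (assumptions for illustration).
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Cast parameters for pure fp16 (AMP level O2) and convert the forward
# pass to static graph (dy2stat).
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')
static_model = paddle.jit.to_static(model)

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 10])
with paddle.amp.auto_cast(
        level='O2',
        custom_white_list={'elementwise_add'},   # placeholder op names
        custom_black_list={'reduce_mean'}):
    out = static_model(x)
    loss = paddle.mean(out)

# Standard loss-scaling step for fp16 training.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()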