
[Dy2stat] Support pure fp16 for dy2stat #36944

Merged: 15 commits merged into PaddlePaddle:develop on Nov 24, 2021

Conversation

@0x45f (Contributor) commented Nov 2, 2021

PR types

New features

PR changes

Others

Describe

  1. Support pure fp16 training for models converted from dygraph to static graph (dy2stat).
    This PR adds a special case to the CastPureFp16Inputs function: the run_program op called during dy2stat execution is skipped and no further cast is applied to it. The run_program op only has an FP32 kernel, and during dy2stat pure fp16 training its inputs have already been cast to fp16 before the op is invoked, so CastPureFp16Inputs must skip the run_program op to avoid casting those inputs back to fp32.
  2. The dygraph pure fp16 training loss and the dy2stat pure fp16 training loss fluctuate relative to each other, so the atol in the unit tests is relaxed to 1e-3 to make them pass reliably. Convergence of dy2stat pure fp16 training has been verified on the mnist and resnet networks; both converge normally.
    • Partial mnist training log
      (image)
    • Partial resnet training results
      (image)
  3. Support the custom_white_list and custom_black_list parameters of the paddle.amp.auto_cast API for both dy2stat AMP and dy2stat pure fp16 training. Before this PR, setting these two parameters had no effect in dy2stat AMP training. A usage sketch follows this list.
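
A minimal usage sketch of dy2stat pure fp16 training with custom cast lists. It is not taken from the PR: SimpleNet, the tensor shapes, and the op names in the custom lists are illustrative, and a GPU with fp16 kernels is assumed.

    import paddle
    from paddle.jit import to_static

    class SimpleNet(paddle.nn.Layer):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.linear = paddle.nn.Linear(16, 10)

        @to_static
        def forward(self, x):
            return self.linear(x)

    model = SimpleNet()
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())
    # level='O2' selects pure fp16 training; decorate casts the parameters accordingly.
    model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    x = paddle.randn([4, 16], dtype='float32')
    # With this PR the custom lists also take effect for the converted (dy2stat) program.
    # The op names below are only examples.
    with paddle.amp.auto_cast(level='O2',
                              custom_white_list={'elementwise_add'},
                              custom_black_list={'reduce_mean'}):
        loss = paddle.mean(model(x))
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
    optimizer.clear_grad()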

@paddle-bot-old (bot) commented Nov 2, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -118,6 +118,17 @@ def _in_amp_guard():
    return False


def _in_pure_fp16_guard():
    tracer = _dygraph_tracer()
    if tracer:
Review comment (Contributor), suggested change:

-    if tracer:
+    return tracer and tracer._amp_level == core.AmpLevel.O2
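
For reference, a sketch of the helper once the suggestion is applied, assuming _dygraph_tracer and core are already imported in this module as the diff context implies:

    def _in_pure_fp16_guard():
        tracer = _dygraph_tracer()
        # AmpLevel.O2 corresponds to pure fp16; returns a falsy value when no tracer is active.
        return tracer and tracer._amp_level == core.AmpLevel.O2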

@Aurelius84 previously approved these changes Nov 22, 2021
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
with fluid.dygraph.guard(place):
Review comment (Contributor):

Dygraph mode is already enabled by default, so this guard is no longer needed.
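
A quick check of the point being made, as a sketch (paddle.in_dynamic_mode() reports whether dygraph mode is active, which is why the fluid.dygraph.guard(place) wrapper above is unnecessary):

    import paddle
    # Paddle 2.x starts in dygraph (dynamic) mode by default.
    assert paddle.in_dynamic_mode()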

        dygraph_loss = self.train(to_static=False)
        self.assertTrue(
            np.allclose(static_loss, dygraph_loss),
            msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
Review comment (Contributor):

Make sure CI stays stable; if the tolerance is relaxed, a NOTE must be added to explain it.
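
A sketch of the assertion with the relaxed tolerance and the requested NOTE; the 1e-3 value comes from the PR description, while the NOTE wording and the completed format argument are assumptions:

    dygraph_loss = self.train(to_static=False)
    # NOTE: pure fp16 losses fluctuate between the dygraph and dy2stat runs,
    # so atol is relaxed to 1e-3 to keep CI stable.
    self.assertTrue(
        np.allclose(static_loss, dygraph_loss, atol=1e-3),
        msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
                                                         dygraph_loss))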

@Aurelius84 (Contributor) left a comment:

LGTM

@Aurelius84 merged commit 52edad6 into PaddlePaddle:develop on Nov 24, 2021
@0x45f deleted the dy2stat_support_pure_fp16 branch on November 24, 2021 at 03:40
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021
* run dy2stat pure fp16 in Linear model

* no use self._pure_fp16_inputs

* add test and fix Adam error in dy2stat pure fp16 training

* use paddle.optimizer.Adam

* run test in gpu

* change test time for CI

* enlarge atol for test_resnet_pure_fp16

* refine code and enlarge atol

* make custom_white_list and custom_black_list take effect for AMP and pure fp16

* check tracer is not None

* use default atol

* change filter_size

* change atol and add some NOTE