-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dy2stat]support pure fp16 for dy2stat #36944
Merged
Aurelius84
merged 15 commits into
PaddlePaddle:develop
from
0x45f:dy2stat_support_pure_fp16
Nov 24, 2021
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
3be2ba2
run dy2stat pure fp16 in Linear model
3e47842
no use self._pure_fp16_inputs
c047452
Merge branch 'develop' into dy2stat_support_pure_fp16
768f134
add test and fix Adam error in dy2stat pure fp16 training
c783c96
use paddle.optimizer.Adam
540c92b
run test in gpu
8271bdb
change test time for CI
fe1f54e
enlarge atol for test_resnet_pure_fp16
526a52b
Merge branch 'develop' into dy2stat_support_pure_fp16
2e4a8e9
refine code and enlarge atol
aeee06e
make custom_white_list and custom_black_list take effect for AMP and …
22950b7
check tracer is not None
edce42a
use default atol
0b46d22
change filter_size
aab57db
change atol and add some NOTE
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_pure_fp16.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import unittest | ||
import numpy as np | ||
from time import time | ||
from test_mnist import MNIST, TestMNIST, SEED, SimpleImgConvPool | ||
from paddle.jit import ProgramTranslator | ||
from paddle.fluid.optimizer import AdamOptimizer | ||
|
||
if paddle.fluid.is_compiled_with_cuda(): | ||
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True}) | ||
|
||
|
||
class TestPureFP16(TestMNIST): | ||
def train_static(self): | ||
return self.train(to_static=True) | ||
|
||
def train_dygraph(self): | ||
return self.train(to_static=False) | ||
|
||
def test_mnist_to_static(self): | ||
if paddle.fluid.is_compiled_with_cuda(): | ||
dygraph_loss = self.train_dygraph() | ||
static_loss = self.train_static() | ||
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here. | ||
self.assertTrue( | ||
np.allclose( | ||
dygraph_loss, static_loss, atol=1e-3), | ||
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss, | ||
static_loss)) | ||
|
||
def train(self, to_static=False): | ||
np.random.seed(SEED) | ||
paddle.seed(SEED) | ||
paddle.framework.random._manual_program_seed(SEED) | ||
|
||
mnist = MNIST() | ||
|
||
if to_static: | ||
print("Successfully to apply @to_static.") | ||
mnist = paddle.jit.to_static(mnist) | ||
|
||
optimizer = paddle.optimizer.Adam( | ||
learning_rate=0.001, parameters=mnist.parameters()) | ||
|
||
scaler = paddle.amp.GradScaler(init_loss_scaling=1024) | ||
|
||
mnist, optimizer = paddle.amp.decorate( | ||
models=mnist, | ||
optimizers=optimizer, | ||
level='O2', | ||
save_dtype='float32') | ||
|
||
loss_data = [] | ||
for epoch in range(self.epoch_num): | ||
start = time() | ||
for batch_id, data in enumerate(self.train_reader()): | ||
dy_x_data = np.array([x[0].reshape(1, 28, 28) | ||
for x in data]).astype('float32') | ||
y_data = np.array( | ||
[x[1] for x in data]).astype('int64').reshape(-1, 1) | ||
|
||
img = paddle.to_tensor(dy_x_data) | ||
label = paddle.to_tensor(y_data) | ||
label.stop_gradient = True | ||
|
||
with paddle.amp.auto_cast( | ||
enable=True, | ||
custom_white_list=None, | ||
custom_black_list=None, | ||
level='O2'): | ||
prediction, acc, avg_loss = mnist(img, label=label) | ||
|
||
scaled = scaler.scale(avg_loss) | ||
scaled.backward() | ||
scaler.minimize(optimizer, scaled) | ||
|
||
loss_data.append(avg_loss.numpy()[0]) | ||
# save checkpoint | ||
mnist.clear_gradients() | ||
if batch_id % 10 == 0: | ||
print( | ||
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}" | ||
.format(epoch, batch_id, | ||
avg_loss.numpy(), acc.numpy(), time() - start)) | ||
start = time() | ||
if batch_id == 50: | ||
break | ||
return loss_data | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
124 changes: 124 additions & 0 deletions
124
python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import math | ||
import time | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
import paddle.fluid as fluid | ||
from paddle.fluid.dygraph import declarative, ProgramTranslator | ||
from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D | ||
from test_resnet import ResNet, optimizer_setting, SEED | ||
|
||
# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout. | ||
batch_size = 2 | ||
epoch_num = 1 | ||
|
||
program_translator = ProgramTranslator() | ||
|
||
if fluid.is_compiled_with_cuda(): | ||
fluid.set_flags({'FLAGS_cudnn_deterministic': True}) | ||
|
||
|
||
def train(to_static, build_strategy=None): | ||
""" | ||
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode. | ||
""" | ||
np.random.seed(SEED) | ||
paddle.seed(SEED) | ||
paddle.framework.random._manual_program_seed(SEED) | ||
|
||
resnet = ResNet() | ||
if to_static: | ||
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy) | ||
optimizer = optimizer_setting(parameter_list=resnet.parameters()) | ||
scaler = paddle.amp.GradScaler(init_loss_scaling=1024) | ||
|
||
resnet, optimizer = paddle.amp.decorate( | ||
models=resnet, optimizers=optimizer, level='O2', save_dtype='float32') | ||
|
||
for epoch in range(epoch_num): | ||
loss_data = [] | ||
total_loss = 0.0 | ||
total_acc1 = 0.0 | ||
total_acc5 = 0.0 | ||
total_sample = 0 | ||
|
||
for batch_id in range(100): | ||
start_time = time.time() | ||
img = paddle.to_tensor( | ||
np.random.random([batch_size, 3, 224, 224]).astype('float32')) | ||
label = paddle.to_tensor( | ||
np.random.randint( | ||
0, 100, [batch_size, 1], dtype='int64')) | ||
img.stop_gradient = True | ||
label.stop_gradient = True | ||
|
||
with paddle.amp.auto_cast( | ||
enable=True, | ||
custom_white_list=None, | ||
custom_black_list=None, | ||
level='O2'): | ||
pred = resnet(img) | ||
loss = fluid.layers.cross_entropy(input=pred, label=label) | ||
avg_loss = fluid.layers.mean(x=pred) | ||
acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1) | ||
acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5) | ||
|
||
scaled = scaler.scale(avg_loss) | ||
scaled.backward() | ||
scaler.minimize(optimizer, scaled) | ||
resnet.clear_gradients() | ||
|
||
loss_data.append(avg_loss.numpy()[0]) | ||
total_loss += avg_loss | ||
total_acc1 += acc_top1 | ||
total_acc5 += acc_top5 | ||
total_sample += 1 | ||
|
||
end_time = time.time() | ||
if batch_id % 2 == 0: | ||
print( "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f" % \ | ||
( epoch, batch_id, total_loss.numpy() / total_sample, \ | ||
total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time)) | ||
if batch_id == 10: | ||
break | ||
|
||
return loss_data | ||
|
||
|
||
class TestResnet(unittest.TestCase): | ||
def train(self, to_static): | ||
program_translator.enable(to_static) | ||
return train(to_static) | ||
|
||
def test_resnet(self): | ||
if fluid.is_compiled_with_cuda(): | ||
static_loss = self.train(to_static=True) | ||
dygraph_loss = self.train(to_static=False) | ||
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here. | ||
self.assertTrue( | ||
np.allclose( | ||
static_loss, dygraph_loss, atol=1e-3), | ||
msg="static_loss: {} \n dygraph_loss: {}".format(static_loss, | ||
dygraph_loss)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确保CI稳定没有问题,若放开tol,需要加NOTE