support offload in sharding stage2 #37904

Merged: ForFishes merged 11 commits into PaddlePaddle:develop from haohongxiang:support_offload_for_sharding_stage2 on Dec 9, 2021
Changes from 2 commits

Commits (11):
ca5b893  merge latest develop branch (haohongxiang)
a5837e1  fix bugs (haohongxiang)
8ac6f25  update (haohongxiang)
881aedf  fix bugs for unittest (haohongxiang)
7e0a6c6  modify for less use of gpu mem (haohongxiang)
697e458  Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into… (haohongxiang)
0a21e8d  fix bugs of using _reset_grad_inplace_version (haohongxiang)
b771b29  update (haohongxiang)
f37411d  update (haohongxiang)
89a75af  modify for CI-Coverage (haohongxiang)
3c6aa89  retrick all CIs (haohongxiang)
@@ -27,11 +27,13 @@
 import paddle
 import paddle.fluid as fluid
 from paddle import framework
+from paddle.fluid import core
 import paddle.distributed as dist
 from paddle.optimizer import Optimizer
+from paddle.fluid.clip import ClipGradByGlobalNorm
 
 from ...utils.internal_storage import ParamStorage
-from ...meta_parallel.sharding.sharding_utils import Type
+from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
 
 # CUDA alignment 256 bytes
 alignment = {"gpu": 256, }
@@ -50,7 +52,7 @@ class ShardingOptimizerStage2(Optimizer):
     .. warning: ShardingOptimizer encapsulates the optimization strategy and integrates it into the optimizer.
 
     .. ZeRO: 1.https://arxiv.org/pdf/1910.02054.pdf 2.https://arxiv.org/pdf/1910.02054.pdf.
 
     """
 
     # TODO (Baibaifan)
@@ -99,16 +101,41 @@ def __init__(self,
 
         self.broadcast_fp16 = broadcast_fp16
         self.param_storages = {}  # {dtype: {rank: InternalStorage}}
 
+        if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
+            logging.warning(
+                "While using ClipGradByGlobalNorm in ShardingOptimizer, the grad clip of original optimizer will be changed."
+            )
+            self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
+                                                      group,
+                                                      paddle.get_device())
+
+        if offload:
+            assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"
+
+        self.offload = offload  # Using for offload
+        self.offload_device = "cpu"
+
+        self._master_params = {}
+
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self.update_opt_status()
 
     def _generate_master_params(self, trainable_params):
-        for param in trainable_params:
-            if param.dtype == Type.fp16.value:
-                self._optim._master_weights[param.name] = paddle.cast(
-                    param, Type.fp32.value)
+        if self.offload:
+            for param in trainable_params:
+                if param.name not in self._master_params.keys():
+                    self._master_params[param.name] = core.VarBase(
+                        name=param.name,
+                        value=param.cast(dtype=Type.fp32.value).numpy(),
+                        place=core.CPUPlace(),
+                        stop_gradient=param.stop_gradient)
+            self._optim._master_weights = self._master_params
+        else:
+            for param in trainable_params:
+                if param.dtype == Type.fp16.value:
+                    self._optim._master_weights[param.name] = paddle.cast(
+                        param, Type.fp32.value)
 
     def update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
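For context, the offload branch of _generate_master_params above keeps an FP32 "master" copy of every trainable FP16 parameter pinned in host memory, keyed by parameter name so the wrapped optimizer can find it through _master_weights; the optimizer then updates those CPU copies instead of the GPU weights. Below is a minimal standalone sketch of the same idea using only public Paddle APIs; the helper name build_cpu_master_weights is illustrative and not part of the PR.

import paddle

def build_cpu_master_weights(trainable_params):
    # Keep an FP32 copy of each FP16 parameter in host memory; the inner
    # optimizer can then update these copies without extra GPU allocations.
    master = {}
    for p in trainable_params:
        if p.dtype == paddle.float16:
            master[p.name] = paddle.to_tensor(
                p.astype('float32').numpy(),  # upcast and copy to the host
                place=paddle.CPUPlace(),
                stop_gradient=p.stop_gradient)
    return master

The PR itself builds the copies with core.VarBase(name=param.name, ...) so that each CPU tensor keeps the original parameter name, which is exactly what the inner optimizer's _master_weights lookup relies on.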
@@ -243,22 +270,43 @@ def step(self):
         A wrapper for Optimizer's step function to finish the update operation of the optimizer.
         """
 
-        # Synchronize optimizer parameters for the current rank
-        if len(self.dtype_rank_params.keys(
-        )) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
-            self._optim._parameter_list = self.dtype_rank_params[
-                Type.fp32.value][self.rank]
-        elif len(self.dtype_rank_params.keys(
-        )) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
-            self._optim._parameter_list = self.dtype_rank_params[
-                Type.fp16.value][self.rank]
-        else:
-            self._optim._parameter_list = self.dtype_rank_params[
-                Type.fp16.value][self.rank] + self.dtype_rank_params[
-                    Type.fp32.value][self.rank]
+        if self.offload:
+            self._optim._parameter_list = [
+                param for name, param in self._master_params.items()
+            ]
+        else:
+            # Synchronize optimizer parameters for the current rank
+            if len(self.dtype_rank_params.keys(
+            )) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
+                self._optim._parameter_list = self.dtype_rank_params[
+                    Type.fp32.value][self.rank]
+            elif len(self.dtype_rank_params.keys(
+            )) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
+                self._optim._parameter_list = self.dtype_rank_params[
+                    Type.fp16.value][self.rank]
+            else:
+                self._optim._parameter_list = self.dtype_rank_params[
+                    Type.fp16.value][self.rank] + self.dtype_rank_params[
+                        Type.fp32.value][self.rank]
 
         # Run the optimizer of the current rank step
-        self._optim.step()
+        if self.offload:
+            with device_guard(self.rank, self.offload_device):
+                self._optim.step()
+
+            for param in self._optim._parameter_list:
+                self._master_params[param.name].set_value(param)
+
+            dev_id = 0 if paddle.get_device() == "cpu" else int(
+                paddle.get_device().split(":")[1])
+
+            for param in self._local_params:
+                if param.name in self._master_params.keys():
+                    param.set_value(self._master_params[param.name].cuda(dev_id)
+                                    .cast(dtype=param.dtype))
+                    self._master_params[param.name].clear_gradient(False)
+        else:
+            self._optim.step()
 
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()

Review comment on the param.set_value(self._master_params[param.name].cuda(dev_id).cast(dtype=param.dtype)) line: this will increase GPU memory usage; the param needs to be released first, and the master parameter's data then shared into it instead of copied.
Author reply: OK.
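In the offload branch of step() above, the wrapped optimizer runs under device_guard(self.rank, self.offload_device), imported from sharding_utils, so its update kernels execute against the CPU master weights; the results are then pushed back to the FP16 GPU parameters via .cuda(dev_id).cast(dtype=param.dtype). A rough sketch of what such a guard can look like, assuming it only switches the globally selected Paddle device and restores it on exit (an illustration, not the PR's actual implementation):

import contextlib
import paddle

@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    # Temporarily route newly launched ops to the requested device, then
    # restore the previously selected device (for example "gpu:0") on exit.
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)

With such a guard, the offload path above amounts to: run self._optim.step() on the CPU copies, then copy each updated master weight back into its GPU parameter and clear the master copy's gradient on the host.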
@@ -286,8 +334,8 @@ def _broadcast_params(self):
                     group=self.group,
                     use_calc_stream=True)
 
             # Multi stream operation will be supported later
             dist.wait(
                 tensor=internal_storage.buffer,
                 group=self.group,
                 use_calc_stream=True)
Review comment: please change this one to .value().get_tensor() as well.
Author reply: OK.
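Putting the pieces together, enabling offload on the stage-2 sharded optimizer is expected to look roughly like the sketch below. The import path, constructor signature and group handling are assumptions inferred from the names appearing in this diff (params, optim, group, offload), not documented API, and may differ between Paddle versions.

# Hedged usage sketch: dygraph sharding stage 2 with CPU offload of fp32 state.
import paddle
from paddle.distributed import fleet
# Assumed location of the class touched by this PR; adjust to your Paddle version.
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import (
    ShardingOptimizerStage2, )

fleet.init(is_collective=True)

model = paddle.nn.Linear(1024, 1024)
model = paddle.amp.decorate(models=model, level='O2')  # pure fp16 params, required by the offload assert
inner_opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())

# offload=True keeps fp32 master weights and optimizer state on the CPU;
# group=None is assumed to fall back to the default collective group.
optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=inner_opt, group=None, offload=True)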