fix a bug of stage2 offload. (#49767)
wuhuachaocoding authored Jan 13, 2023
1 parent d58cca9 commit 1c8531c
Showing 2 changed files with 28 additions and 3 deletions.
@@ -149,6 +149,11 @@ def __init__(
         self._rank = self._group.rank
         self._global_root_rank = self._group.ranks[0]

+        if self._dp_group is not None and self._dp_group.nranks > 1:
+            assert (
+                not offload
+            ), "Not support! when using offload with sharding stage2, please use pure sharding stage2, exclude data parallel."
+
         # Synchronous all ranks models
         if pertrain_sync_models:
             self._sync_params_and_buffers()
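The assert above means that offload with sharding stage2 is only supported as pure sharding, with no separate data-parallel group. Below is a minimal sketch of a construction that satisfies the new guard, built only from the constructor arguments visible in this diff; the import paths and the helper name build_offload_stage2 are assumptions, not part of this commit.

# Assumed import locations; this diff does not show the imports.
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
    GroupShardedStage2,
)


def build_offload_stage2(model, base_optimizer):
    # Hypothetical helper: wrap a model/optimizer pair for sharding stage2
    # with offload. After this fix, offload=True must not be combined with a
    # multi-rank dp_group; otherwise __init__ raises the AssertionError added above.
    opt = GroupShardedOptimizerStage2(
        params=base_optimizer._parameter_list,
        optim=base_optimizer,
        offload=True,
        dp_group=None,  # pure sharding stage2: no separate data-parallel group
    )
    wrapped = GroupShardedStage2(model, opt, buffer_max_size=2**21)
    return wrapped, opt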
@@ -164,6 +169,7 @@ def __init__(
         if (
             hcg
             and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            and not offload
         ):
             self._optim._grad_clip = HybridParallelClipGrad(
                 self._optim._grad_clip, hcg
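The added "and not offload" term means that, when offload is enabled, the optimizer keeps its original grad clip instead of the HybridParallelClipGrad wrapper. A standalone sketch of that decision follows; the import locations are assumptions and the helper name maybe_wrap_grad_clip is hypothetical.

# Sketch only: restates the post-patch condition outside of __init__.
# Import locations below are assumptions; this diff does not show them.
from paddle.distributed.fleet.base.topology import ParallelMode
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
    HybridParallelClipGrad,
)


def maybe_wrap_grad_clip(optim, hcg, offload):
    # Wrap the original grad clip only when a hybrid communication group is
    # present, the parallel mode is not plain data parallel, and offload is
    # disabled; the "and not offload" part is what this commit adds.
    if (
        hcg
        and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
        and not offload
    ):
        optim._grad_clip = HybridParallelClipGrad(optim._grad_clip, hcg)
    return optim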
@@ -37,17 +37,29 @@
 paddle.seed(seed)


-def train_mlp(model, offload=False):
+def train_mlp(model, offload=False, test=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)

     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = GroupShardedScaler(scaler)

+    dp_group = (
+        None
+        if not test
+        else paddle.distributed.new_group(
+            list(range(paddle.distributed.get_world_size()))
+        )
+    )
     optimizer = GroupShardedOptimizerStage2(
-        params=optimizer._parameter_list, optim=optimizer, offload=offload
+        params=optimizer._parameter_list,
+        optim=optimizer,
+        offload=offload,
+        dp_group=dp_group,
     )
-    model = GroupShardedStage2(model, optimizer, buffer_max_size=2**21)
+    model = GroupShardedStage2(
+        model, optimizer, buffer_max_size=2**21, dp_group=dp_group
+    )

     paddle.seed(2023)
     np.random.seed(2023)
@@ -103,6 +115,13 @@ def test_sharding_stage2_offload():
             rtol=5e-3,
             atol=5e-3,
         )
+
+    # just to test assert error for the rate of coverage
+    try:
+        train_mlp(mlp_offload, offload=True, test=True)
+    except Exception as e:
+        assert isinstance(e, AssertionError)
+
     return

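The try/except block added above exists only to drive the new assertion for test coverage. An equivalent way to express the same negative check is sketched below; pytest and the function name are assumptions, and the actual test keeps the try/except form.

import pytest  # assumption: the real test file does not use pytest.raises


def check_offload_with_dp_group_rejected(model):
    # "model" stands in for the mlp_offload instance the surrounding test
    # already built; train_mlp is the function modified in this commit.
    # Passing a world-sized dp_group together with offload=True (the
    # test=True path in train_mlp) must trip the new AssertionError.
    with pytest.raises(AssertionError):
        train_mlp(model, offload=True, test=True)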
