[Dygraph] Support param groups in grad_clip (#39175)
* support param groups in grad_clip

* update

* modify for review
haohongxiang authored Jan 25, 2022
1 parent faf517b commit b0cca48
Showing 6 changed files with 49 additions and 21 deletions.
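For context on what "param groups" means here: instead of a flat parameter list, a dygraph optimizer can be built from a list of parameter-group dicts, and a group dict may also carry its own "grad_clip" entry. A minimal sketch of such a setup, mirroring the new TestPPClipGradParamGroup test added below (the Linear layer and sizes are illustrative, not part of the commit):

import paddle

# Optimizer built from parameter groups (a list of dicts) rather than a flat
# parameter list; the layer and hyperparameters are only illustrative.
model = paddle.nn.Linear(16, 4)
grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[2], values=[0.001, 0.002])
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler,
    grad_clip=grad_clip,
    parameters=[{
        "params": model.parameters()
    }])

Before this commit, the wrappers changed below (HybridParallelClipGrad / ShardingClipGrad) replaced only the optimizer-level _grad_clip; the diff also walks _param_groups and replaces any per-group "grad_clip" entry.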
Changed file 1 of 6:
@@ -49,8 +49,6 @@ def __init__(self, clip, hcg):

     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_dist_fp16 = []
         sum_square_dist_fp32 = []
         sum_square_not_dist_fp16 = []
@@ -153,15 +151,14 @@ def _dygraph_clip(self, params_grads):
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)

-        return params_and_grads
+        return params_grads

     def __getattr__(self, item):
         return getattr(self._clip, item)
@@ -201,6 +198,12 @@ def __init__(self, optimizer, hcg, strategy):
             else:
                 self._inner_opt._grad_clip = HybridParallelClipGrad(
                     self._inner_opt._grad_clip, hcg)
+                if self._inner_opt._parameter_list and isinstance(
+                        self._inner_opt._parameter_list[0], dict):
+                    for item in self._inner_opt._param_groups:
+                        if "grad_clip" in item.keys():
+                            item["grad_clip"] = HybridParallelClipGrad(
+                                self._inner_opt._grad_clip, hcg)

     @imperative_base.no_grad
     @framework.dygraph_only
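The block added above is the core of the change for the hybrid-parallel optimizer: when the inner optimizer was built from parameter groups, each group's own "grad_clip" is also replaced with HybridParallelClipGrad. A standalone sketch of that per-group wrapping pattern (the helper name is hypothetical, for illustration only):

# Hypothetical helper, not from the commit: applies a parallel-aware wrapper
# (e.g. HybridParallelClipGrad or ShardingClipGrad) to every per-group clip.
def wrap_param_group_clips(param_groups, wrap_clip):
    for group in param_groups:
        if isinstance(group, dict) and "grad_clip" in group:
            group["grad_clip"] = wrap_clip(group["grad_clip"])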
Changed file 2 of 6:
@@ -109,6 +109,13 @@ def __init__(self,
             self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
                                                       paddle.get_device(),
                                                       self.group)
+            if self._optim._parameter_list and isinstance(
+                    self._optim._parameter_list[0], dict):
+                for item in self._optim._param_groups:
+                    if "grad_clip" in item.keys():
+                        item["grad_clip"] = ShardingClipGrad(
+                            self._optim._grad_clip,
+                            paddle.get_device(), self.group)

         if offload:
             assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"
Changed file 3 of 6:
@@ -57,8 +57,6 @@ def __init__(self, clip, device, group):

     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_fp16 = []
         sum_square_fp32 = []

@@ -114,15 +112,14 @@ def _dygraph_clip(self, params_grads):
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)

-        return params_and_grads
+        return params_grads

     def __getattr__(self, item):
         return getattr(self._clip, item)
Changed file 4 of 6:
@@ -159,10 +159,13 @@ def test_dp_stage2():
     mlp2 = MLP()
     mlp3 = MLP()
     mlp4 = MLP()
+    mlp5 = MLP()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
+    mlp5.set_state_dict(state_dict)
+
     dp_params = train_mlp(
         mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
@@ -181,6 +184,11 @@ def test_dp_stage2():
         rtol=1e-5,
         atol=1e-5)

+    stage2_params = train_mlp(
+        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+    for i in range(len(dp_params)):
+        np.testing.assert_allclose(
+            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
     return


Changed file 5 of 6:
@@ -49,7 +49,7 @@ def train_mlp(model, offload=False):
     optimizer = ShardingOptimizerStage2(
         params=model.parameters(), optim=optimizer, offload=offload)
     model = ShardingStage2(
-        model, optimizer, buffer_max_size=2**21, accumulate_grads=True)
+        model, optimizer, buffer_max_size=2**21, accumulate_grads=False)

     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
@@ -98,12 +98,11 @@ def test_sharding_stage2_offload():
     mlp_offload_params = train_mlp(mlp_offload, offload=True)

     for i in range(len(mlp_params)):
-        for j in range(len(mlp_offload_params)):
-            if mlp_params[i].name == mlp_offload_params[j].name:
-                np.testing.assert_allclose(
-                    mlp_params[i].numpy(),
-                    mlp_offload_params[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            mlp_params[i].numpy(),
+            mlp_offload_params[i].numpy(),
+            rtol=5e-3,
+            atol=5e-3)
     return


Changed file 6 of 6:
@@ -31,5 +31,19 @@ def build_optimizer(self, model):
         return scheduler, optimizer


+class TestPPClipGradParamGroup(TestDistPPTraning):
+    def build_optimizer(self, model):
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=scheduler,
+            grad_clip=grad_clip,
+            parameters=[{
+                "params": model.parameters()
+            }])
+        return scheduler, optimizer
+
+
 if __name__ == "__main__":
     unittest.main()
