Skip to content

Commit

Permalink
Cherry-pick the support of bf16 of grad_clip, in PaddlePaddle#51285.
Browse files Browse the repository at this point in the history
  • Loading branch information
Xreki committed Apr 12, 2023
1 parent 3869a3b commit 273ead7
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions python/paddle/fluid/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
return old_value


# Module-level switch read and written by _allow_pure_bf16_global_norm_clip().
_allow_pure_bf16_global_norm_clip_flag = False


def _allow_pure_bf16_global_norm_clip(*args):
    """Get or set the module-level pure-bf16 global-norm-clip flag.

    Called with no arguments, returns the current flag value without
    modifying it. Called with exactly one ``bool``, stores it as the new
    flag value and returns the value the flag held before the call, so a
    caller can restore it later.

    Args:
        *args: Either empty (query) or a single ``bool`` (update).

    Returns:
        bool: The current flag when querying, the previous flag when
        updating.

    Raises:
        AssertionError: If called with more than one argument, or with a
            single argument that is not a ``bool``.
    """
    global _allow_pure_bf16_global_norm_clip_flag
    if not args:
        return _allow_pure_bf16_global_norm_clip_flag

    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let a non-bool (or surplus
    # arguments) silently corrupt the flag. AssertionError is kept so the
    # failure type seen by existing callers does not change.
    if len(args) != 1 or not isinstance(args[0], bool):
        raise AssertionError(
            "_allow_pure_bf16_global_norm_clip() takes no arguments or "
            "exactly one bool"
        )

    old_value = _allow_pure_bf16_global_norm_clip_flag
    _allow_pure_bf16_global_norm_clip_flag = args[0]
    return old_value


class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
Expand Down Expand Up @@ -584,6 +598,7 @@ def _static_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_bf16 = []
sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
Expand All @@ -598,18 +613,27 @@ def _static_clip(self, params_grads):
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad
)

sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.BF16:
sum_square_list_bf16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)

if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
raise NotSupportedError(
'FP16 and BF16 are not supported at the same time.'
)

# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_bf16)
+ len(sum_square_list_fp32)
== 0
):
Expand All @@ -620,7 +644,7 @@ def _static_clip(self, params_grads):

global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
if (
sum_square_list_fp32
or sum_square_list
Expand All @@ -631,8 +655,20 @@ def _static_clip(self, params_grads):
)
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_bf16) > 0:
global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_bf16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_bf16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_bf16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
Expand All @@ -641,23 +677,24 @@ def _static_clip(self, params_grads):
)
if len(sum_square_list) > 0:
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var_other_dtype = paddle.add_n(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)

global_norm_var = (
layers.sums(global_norm_var)
paddle.add_n(global_norm_var)
if len(global_norm_var) > 1
else global_norm_var[0]
)
global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm

global_norm_var = paddle.sqrt(x=global_norm_var)
max_global_norm = paddle.full(
shape=[1],
dtype=global_norm_var.dtype,
fill_value=self.clip_norm,
)
scale_var = layers.elementwise_div(
scale_var = paddle.divide(
x=max_global_norm,
y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var
),
y=paddle.maximum(x=max_global_norm, y=global_norm_var),
)
param_new_grad_name_dict = dict()
for p, g in params_grads:
Expand All @@ -671,9 +708,8 @@ def _static_clip(self, params_grads):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (
scale_var.astype('float16')
if new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16
scale_var.astype(new_g.dtype)
if scale_var.dtype != new_g.dtype
else scale_var
)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
Expand Down

0 comments on commit 273ead7

Please sign in to comment.