From 273ead714f8890d77bc155767e0db61a11e8df7f Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 12 Apr 2023 12:44:16 +0800 Subject: [PATCH] Cherry-pick the support of bf16 of grad_clip, in #51285. --- python/paddle/fluid/clip.py | 64 +++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index d803de22606e16..2f38cd978889b6 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -420,6 +420,20 @@ def _allow_pure_fp16_global_norm_clip(*args): return old_value +_allow_pure_bf16_global_norm_clip_flag = False + + +def _allow_pure_bf16_global_norm_clip(*args): + global _allow_pure_bf16_global_norm_clip_flag + if len(args) == 0: + return _allow_pure_bf16_global_norm_clip_flag + else: + assert len(args) == 1 and isinstance(args[0], bool) + old_value = _allow_pure_bf16_global_norm_clip_flag + _allow_pure_bf16_global_norm_clip_flag = args[0] + return old_value + + class ClipGradByGlobalNorm(ClipGradBase): r""" Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in @@ -584,6 +598,7 @@ def _static_clip(self, params_grads): params_and_grads = [] sum_square_list = [] sum_square_list_fp16 = [] + sum_square_list_bf16 = [] sum_square_list_fp32 = [] with framework.name_scope('gradient_clip'): for p, g in params_grads: @@ -598,18 +613,27 @@ def _static_clip(self, params_grads): merge_grad = layers.get_tensor_from_selected_rows( merge_grad ) + sum_square = _squared_l2_norm(merge_grad) if sum_square.dtype == core.VarDesc.VarType.FP16: sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.BF16: + sum_square_list_bf16.append(sum_square) elif sum_square.dtype == core.VarDesc.VarType.FP32: sum_square_list_fp32.append(sum_square) else: sum_square_list.append(sum_square) + if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0: + raise NotSupportedError( + 'FP16 and BF16 are not supported at the same time.' + ) + # all parameters have been filterd out if ( len(sum_square_list) + len(sum_square_list_fp16) + + len(sum_square_list_bf16) + len(sum_square_list_fp32) == 0 ): @@ -620,7 +644,7 @@ def _static_clip(self, params_grads): global_norm_var = [] if len(sum_square_list_fp16) > 0: - global_norm_var_fp16 = layers.sums(sum_square_list_fp16) + global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16) if ( sum_square_list_fp32 or sum_square_list @@ -631,8 +655,20 @@ def _static_clip(self, params_grads): ) else: global_norm_var.append(global_norm_var_fp16) + if len(sum_square_list_bf16) > 0: + global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16) + if ( + sum_square_list_fp32 + or sum_square_list + or not _allow_pure_bf16_global_norm_clip() + ): + global_norm_var.append( + global_norm_var_bf16.astype(sum_dtype) + ) + else: + global_norm_var.append(global_norm_var_bf16) if len(sum_square_list_fp32) > 0: - global_norm_var_fp32 = layers.sums(sum_square_list_fp32) + global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32) if sum_dtype == 'float32': global_norm_var.append(global_norm_var_fp32) else: @@ -641,23 +677,24 @@ def _static_clip(self, params_grads): ) if len(sum_square_list) > 0: # fp64 - global_norm_var_other_dtype = layers.sums(sum_square_list) + global_norm_var_other_dtype = paddle.add_n(sum_square_list) global_norm_var.append(global_norm_var_other_dtype) global_norm_var = ( - layers.sums(global_norm_var) + paddle.add_n(global_norm_var) if len(global_norm_var) > 1 else global_norm_var[0] ) - global_norm_var = layers.sqrt(x=global_norm_var) - max_global_norm = layers.fill_constant( - shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm + + global_norm_var = paddle.sqrt(x=global_norm_var) + max_global_norm = paddle.full( + shape=[1], + dtype=global_norm_var.dtype, + fill_value=self.clip_norm, ) - scale_var = layers.elementwise_div( + scale_var = paddle.divide( x=max_global_norm, - y=layers.elementwise_max( - x=max_global_norm, y=global_norm_var - ), + y=paddle.maximum(x=max_global_norm, y=global_norm_var), ) param_new_grad_name_dict = dict() for p, g in params_grads: @@ -671,9 +708,8 @@ def _static_clip(self, params_grads): new_g = _cast_to_mp_type_if_enabled(g) # inplace scale_input = ( - scale_var.astype('float16') - if new_g.dtype == core.VarDesc.VarType.FP16 - and scale_var.dtype != core.VarDesc.VarType.FP16 + scale_var.astype(new_g.dtype) + if scale_var.dtype != new_g.dtype else scale_var ) # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g