diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py
index bcc64a50ae218..60a4bb324c8fc 100644
--- a/python/paddle/distributed/auto_parallel/constants.py
+++ b/python/paddle/distributed/auto_parallel/constants.py
@@ -104,6 +104,9 @@ def set_field_default_config(category, field, default_value):
 set_field_default_config(GRADIENT_MERGE, "enable", False)
 set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
 set_field_default_config(GRADIENT_MERGE, "avg", True)
+set_field_default_config(
+    GRADIENT_MERGE, "dp_gradient_sync_after_accumulate", False
+)
 
 #########################################
 # pipeline configuration
diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
index 27a13fd1d9107..99a425614ff2a 100644
--- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -416,6 +416,12 @@ def _apply_post_optimization(
             )
             dp_pass.apply([main_program], [startup_program], self._pass_context)
 
+        dp_gradient_sync_after_accumulate = (
+            self._strategy.gradient_merge.dp_gradient_sync_after_accumulate
+        )
+        if dp_gradient_sync_after_accumulate:
+            global_params_grads = params_grads
+
         if self._strategy.sharding.enable:
             config = copy.deepcopy(self._strategy.sharding.to_dict())
             config["dist_context"] = self._dist_context
@@ -485,7 +491,10 @@ def _apply_post_optimization(
         if self.is_train and self._strategy.gradient_merge.enable:
             config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
             config["dist_context"] = self._dist_context
-            config["params_grads"] = params_grads
+            if dp_gradient_sync_after_accumulate:
+                config["params_grads"] = global_params_grads
+            else:
+                config["params_grads"] = params_grads
             auto_parallel_gradient_merge_pass = new_pass(
                 "auto_parallel_gradient_merge_pass", config
             )
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index ab41c2100982a..f5298782fc3ce 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -16,6 +16,10 @@
 
 import paddle
 from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.static.operators.common import (
+    is_data_parallel_reduce_op,
+    is_data_parallel_scale_op,
+)
 from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
@@ -260,6 +264,51 @@ def _append_gradient_merge_backward_op(
     return new_params_grads, grad_to_gradient_merge
 
 
+def _move_reduce_to_optimizer_ops_block(
+    main_program, optimize_ops_block, params_grads
+):
+    main_block = main_program.global_block()
+    removed_op_idx = []
+    params_grads_name = [grad.name for _, grad in params_grads]
+
+    for idx, op in list(enumerate(main_block.ops)):
+        if is_data_parallel_reduce_op(op):
+            op_input_names = op.desc.input_arg_names()
+            # NOTE(sonder): When "@RENAME@" appears in the input name, the op has been renamed.
+            # Such input names are produced by the shared-parameter policy.
+            # Gradient merge should only accumulate the gradients of ops that have not been renamed.
+            if "@RENAME" in op_input_names[0]:
+                continue
+
+            reduce_op_desc = optimize_ops_block.desc._insert_op(
+                len(removed_op_idx)
+            )
+            reduce_op_desc.copy_from(op.desc)
+            reduce_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
+            removed_op_idx.append(idx)
+
+            if op.type in ["c_allreduce_sum", "c_reduce_sum"]:
+                scale_index = idx + 1
+                while scale_index < len(main_block.ops):
+                    if is_data_parallel_scale_op(main_block.ops[scale_index]):
+                        scale_op_desc = optimize_ops_block.desc._insert_op(
+                            len(removed_op_idx)
+                        )
+                        scale_op_desc.copy_from(
+                            main_block.ops[scale_index].desc
+                        )
+                        scale_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
+                        removed_op_idx.append(scale_index)
+                        break
+                    scale_index += 1
+
+    for idx in removed_op_idx[::-1]:
+        main_block._remove_op(idx, sync=False)
+
+    main_block._sync_with_cpp()
+    return optimize_ops_block
+
+
 def _create_cond_block_and_update_optimizer(
     main_program,
     cond_var,
@@ -390,7 +439,13 @@ def true_apply_gradient():
 
 
 def parse_program(
-    main_program, startup_program, params_grads, k_steps, avg, dist_context
+    main_program,
+    startup_program,
+    params_grads,
+    k_steps,
+    avg,
+    dist_context,
+    dp_gradient_sync_after_accumulate,
 ):
     # 1 remove optimizer_op from main_program
     optimize_ops_block = _remove_and_get_optimizer_op(
@@ -405,10 +460,16 @@ def parse_program(
         main_program, startup_program, params_grads, dist_context
     )
 
-    # 3 create gradient_merge_cond
+    if dp_gradient_sync_after_accumulate:
+        # 3 move reduce ops to optimize_ops_block
+        optimize_ops_block = _move_reduce_to_optimizer_ops_block(
+            main_program, optimize_ops_block, params_grads
+        )
+
+    # 4 create gradient_merge_cond
     cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
 
-    # 4 create ConditionalBlock and append gradient merge optimizer ops
+    # 5 create ConditionalBlock and append gradient merge optimizer ops
     _create_cond_block_and_update_optimizer(
         main_program,
         cond_var,
@@ -444,6 +505,9 @@ def _apply_single_impl(self, main_program, startup_program, context):
         avg = self.get_attr("avg", False)
         dist_context = self.get_attr("dist_context")
         params_grads = self.get_attr("params_grads")
+        dp_gradient_sync_after_accumulate = self.get_attr(
+            "dp_gradient_sync_after_accumulate", False
+        )
         with paddle.static.program_guard(main_program, startup_program):
             parse_program(
                 main_program,
@@ -452,6 +516,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
                 k_steps,
                 avg,
                 dist_context,
+                dp_gradient_sync_after_accumulate,
             )
             main_program._sync_with_cpp()