diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc
index 3690c67ac58f4..52608af201d1e 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -439,10 +439,12 @@ void PirInterpreter::UpdateNcclOpNum() {
   static std::set<std::string> nccl_op_set = {
       "pd_op.c_softmax_with_cross_entropy",
       "pd_op.c_allgather",
+      "pd_op.c_allreduce_avg",
       "pd_op.c_allreduce_max",
       "pd_op.c_allreduce_min",
       "pd_op.c_allreduce_sum",
       "pd_op.c_allreduce_prod",
+      "pd_op.c_reduce_avg",
       "pd_op.c_reduce_max",
       "pd_op.c_reduce_min",
       "pd_op.c_reduce_prod",
@@ -509,10 +511,12 @@ void PirInterpreter::UpdateNcclOpNum() {
       "pd_op.reduce_grad",
       "pd_op.c_softmax_with_cross_entropy_",
       "pd_op.c_allgather_",
+      "pd_op.c_allreduce_avg_",
       "pd_op.c_allreduce_max_",
       "pd_op.c_allreduce_min_",
       "pd_op.c_allreduce_sum_",
       "pd_op.c_allreduce_prod_",
+      "pd_op.c_reduce_avg_",
       "pd_op.c_reduce_max_",
       "pd_op.c_reduce_min_",
       "pd_op.c_reduce_prod_",
diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc
new file mode 100644
index 0000000000000..3343406a02b6c
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc
@@ -0,0 +1,45 @@
+/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace framework {
+class OpDesc;
+}  // namespace framework
+namespace imperative {
+class OpBase;
+}  // namespace imperative
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+
+class CAllReduceAvgOpMaker : public CAllReduceOpMaker {
+ protected:
+  std::string GetName() const override { return "Avg"; }
+};
+
+DECLARE_INPLACE_OP_INFERER(AllreduceAvgInplaceInferer, {"X", "Out"});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_avg,
+                             ops::CAllReduceOp,
+                             ops::CAllReduceAvgOpMaker,
+                             ops::AllreduceAvgInplaceInferer)
diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc
new file mode 100644
index 0000000000000..d3f0b45f64432
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc
@@ -0,0 +1,35 @@
+/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceAvg, kRedAvg)
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+PD_REGISTER_STRUCT_KERNEL(c_allreduce_avg,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::CAllReduceAvgCUDAKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16,
+                          plat::bfloat16) {}
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 95e02e35adfc4..1fd4a8b73d43a 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -48,7 +48,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
 namespace paddle {
 namespace operators {
 
-enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
+enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg };
 
 class CAllReduceOp : public framework::OperatorWithKernel {
  public:
@@ -413,6 +413,12 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
         nccl_red_type = ncclProd;
         break;
 
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
+      case kRedAvg:
+        nccl_red_type = ncclAvg;
+        break;
+#endif
+
       default:
         PADDLE_THROW(platform::errors::InvalidArgument(
             "Invalid reduce type: %d", red_type));
diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cc
new file mode 100644
index 0000000000000..53ce6e221a9f8
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace framework {
+class OpDesc;
+template <typename T>
+class EmptyGradOpMaker;
+}  // namespace framework
+namespace imperative {
+class OpBase;
+}  // namespace imperative
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+
+class CReduceAvgOpMaker : public CReduceOpMaker {
+ protected:
+  std::string GetName() const override { return "Avg"; }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(c_reduce_avg,
+                             ops::CReduceOp,
+                             ops::CReduceAvgOpMaker);
diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc
new file mode 100644
index 0000000000000..07d2cc748900e
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc
@@ -0,0 +1,35 @@
+/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace operators {
+DEFINE_C_REDUCE_CUDA_KERNEL(CReduceAvg, kRedAvg);
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+PD_REGISTER_STRUCT_KERNEL(c_reduce_avg,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::CReduceAvgCUDAKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16,
+                          plat::bfloat16) {}
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index e8e240c9b5525..d90fb88fe8f3f 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -50,7 +50,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
 namespace paddle {
 namespace operators {
 
-enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
+enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg };
 
 class CReduceOp : public framework::OperatorWithKernel {
  public:
@@ -304,6 +304,12 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
         nccl_red_type = ncclProd;
         break;
 
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
+      case kRedAvg:
+        nccl_red_type = ncclAvg;
+        break;
+#endif
+
       default:
         PADDLE_ENFORCE_EQ(true,
                           false,
diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index 2cbcb29f705b3..c44748d283746 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -119,6 +119,8 @@ NO_NEED_GEN_STATIC_ONLY_APIS = [
     'add_n_',
     'c_allgather',
+    'c_allreduce_avg',
+    'c_allreduce_avg_',
     'c_allreduce_max',
     'c_allreduce_min',
     'c_allreduce_min_',
@@ -158,6 +160,8 @@
     'soft_relu',
     'uniform_random_batch_size_like',
     'match_matrix_tensor',
+    'c_reduce_avg',
+    'c_reduce_avg_',
     'c_reduce_max',
     'c_reduce_max_',
     'c_reduce_min',
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index d856c58a75550..90b64f0395f79 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -138,6 +138,16 @@
   kernel :
     func : c_allgather
 
+- op : c_allreduce_avg
+  args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
+  output : Tensor(out)
+  infer_meta :
+    func : AllReduceInferMeta
+    param : [x]
+  kernel :
+    func : c_allreduce_avg
+  inplace : (x -> out)
+
 - op : c_allreduce_max
   args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
   output : Tensor(out)
@@ -218,6 +228,16 @@
     func : c_identity
   inplace : (x -> out)
 
+- op : c_reduce_avg
+  args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
+  output : Tensor(out)
+  infer_meta :
+    func : DistReduceInferMeta
+    param : [x]
+  kernel :
+    func : c_reduce_avg
+  inplace : (x -> out)
+
 - op : c_reduce_max
   args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
   output : Tensor(out)
diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc
index c17a7fb6839cc..cca683ed0bbef 100644
--- a/paddle/fluid/pir/dialect/operator/utils/utils.cc
+++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -50,6 +50,8 @@ const std::unordered_set<std::string> LegacyOpList = {
     CAllreduceProd_Op::name(),
     CAllreduceSumOp::name(),
     CAllreduceSum_Op::name(),
+    CAllreduceAvgOp::name(),
+    CAllreduceAvg_Op::name(),
     CReduceSumOp::name(),
     CReduceSum_Op::name(),
     CAllreduceMax_Op::name(),
@@ -86,6 +88,8 @@ const std::unordered_set<std::string> LegacyOpList = {
     paddle::onednn::dialect::MultiGruOp::name(),
     paddle::onednn::dialect::FusionLstmOp::name(),
 #endif
+    CReduceAvgOp::name(),
+    CReduceAvg_Op::name(),
     CReduceMaxOp::name(),
     CReduceMinOp::name(),
     CReduceProdOp::name(),
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 44a66c60e8078..fe232667f259d 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -3513,6 +3513,12 @@
   outputs :
     out: Out
 
+- op: c_allreduce_avg
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_allreduce_max
   inputs :
     x : X
@@ -3549,6 +3555,12 @@
   outputs :
     out: Out
 
+- op: c_reduce_avg
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_reduce_max
   inputs :
     x : X
diff --git a/python/env_dict.py.in b/python/env_dict.py.in
index 79e4e0704505a..a276adb00085e 100644
--- a/python/env_dict.py.in
+++ b/python/env_dict.py.in
@@ -1,9 +1,11 @@
 env_dict={
+    'NCCL_VERSION':'@NCCL_VERSION@',
     'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@',
     'PADDLE_VERSION':'@PADDLE_VERSION@',
     'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@',
     'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@',
     'WITH_GPU':'@WITH_GPU@',
+    'WITH_NCCL':'@WITH_NCCL@',
     'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@',
     'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@',
     'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@',
diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py
index bcc64a50ae218..2fad0a278aeff 100644
--- a/python/paddle/distributed/auto_parallel/constants.py
+++ b/python/paddle/distributed/auto_parallel/constants.py
@@ -42,6 +42,7 @@ def set_field_default_config(category, field, default_value):
 BASE = "base"
 set_field_default_config(BASE, "auto_mode", "semi")
 set_field_default_config(BASE, "gradient_scale", True)
+set_field_default_config(BASE, "gradient_scale_using_allreduce_avg", False)
 set_field_default_config(BASE, "use_cache", True)
 set_field_default_config(BASE, "return_numpy", True)
 set_field_default_config(BASE, "all_ranks", False)
diff --git a/python/paddle/distributed/auto_parallel/static/dist_context.py b/python/paddle/distributed/auto_parallel/static/dist_context.py
index eefc0d332957f..12d88ba779d3f 100644
--- a/python/paddle/distributed/auto_parallel/static/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/static/dist_context.py
@@ -127,6 +127,9 @@ def __init__(
         # flag whether scale gradient with dp size
         self._gradient_scale = True
 
+        # whether to use allreduce_avg to scale gradient, i.e., allreduce_sum + scale -> allreduce_avg
+        self._gradient_scale_using_allreduce_avg = False
+
         # A flag indicates whether the used parallelism is data parallel
         self._data_parallel = False
 
@@ -220,6 +223,18 @@ def gradient_scale(self):
     def gradient_scale(self, gs):
         self._gradient_scale = gs
 
+    @property
+    def gradient_scale_using_allreduce_avg(self):
+        return self._gradient_scale_using_allreduce_avg
+
+    @gradient_scale_using_allreduce_avg.setter
+    def gradient_scale_using_allreduce_avg(
+        self, gradient_scale_using_allreduce_avg
+    ):
+        self._gradient_scale_using_allreduce_avg = (
+            gradient_scale_using_allreduce_avg
+        )
+
     @property
     def data_parallel(self):
         return self._data_parallel
diff --git a/python/paddle/distributed/auto_parallel/static/dist_op.py b/python/paddle/distributed/auto_parallel/static/dist_op.py
index b27e27ee98330..8d28c43eef4d7 100644
--- a/python/paddle/distributed/auto_parallel/static/dist_op.py
+++ b/python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -130,6 +130,8 @@ def __str__(self):
             f", process_mesh ({annotated_str}): {self.dist_attr.process_mesh}"
         )
 
+        str += f" , execution_stream: {self.dist_attr.execution_stream}"
+
         for arg_name in self.serial_op.desc.input_arg_names():
             try:
                 dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index 401737bb13ac6..2215dc9475117 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -779,6 +779,11 @@ def _build(self, mode):
             self._json_config,
         )
         self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
+        self._dist_contexts[
+            mode
+        ].gradient_scale_using_allreduce_avg = (
+            self._strategy.gradient_scale_using_allreduce_avg
+        )
         self._fwd_main_progs[mode] = serial_main_prog.clone()
 
     def _optimization_tuning(self, mode, dataset, batch_size):
diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py
index 9f95b049cce3c..c6de9955e08ea 100644
--- a/python/paddle/distributed/auto_parallel/static/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/static/operators/common.py
@@ -503,6 +503,19 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
     dist_op_context = dist_ctx.dist_op_context
     main_block = dist_op_context.work_block
 
+    allreduce_type = "c_allreduce_sum"
+    need_scale = dist_ctx.gradient_scale
+    scale_using_allreduce_avg = dist_ctx.gradient_scale_using_allreduce_avg
+
+    # With nccl_version > 2.10.00, we can use c_allreduce_avg to replace c_allreduce_sum and eliminate the scale op.
+    if (
+        need_scale
+        and scale_using_allreduce_avg
+        and int(paddle.version.nccl()) > 21000
+    ):
+        allreduce_type = "c_allreduce_avg"
+        need_scale = False
+
     for group in groups:
         group_size = len(group.ranks)
 
@@ -510,7 +523,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
             added_ops = []
             grad_var = main_block.var(var_name)
             allreduce_op = main_block.append_op(
-                type='c_allreduce_sum',
+                type=allreduce_type,
                 inputs={'X': [grad_var]},
                 outputs={'Out': [grad_var]},
                 attrs={
@@ -524,7 +537,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
             )
             added_ops.append(allreduce_op)
 
-            if dist_ctx.gradient_scale:
+            if need_scale:
                 scale_op = main_block.append_op(
                     type='scale',
                     inputs={'X': grad_var},
@@ -654,7 +667,13 @@ def is_data_parallel_scale_op(op):
 
 def is_data_parallel_reduce_op(op):
     return (
-        op.type in ["c_reduce_sum", "c_allreduce_sum"]
+        op.type
+        in [
+            "c_allreduce_sum",
+            "c_allreduce_avg",
+            "c_reduce_sum",
+            "c_reduce_avg",
+        ]
         and op.desc.has_attr("op_namescope")
         and ParallelMode.DataParallel in op.desc.attr("op_namescope")
     )
diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index 16be4d0c7a43b..ec775f54b9fe1 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -2193,12 +2193,13 @@ def insert_dependencies_for_vars(
     sync=False,
     op_namescope=None,
     use_nop=False,
+    skip_insert_when_sequential_run=True,
 ):
     """
     dependency: op that generates prior_vars should be run before op that generates post_vars
     """
 
-    if is_sequential_run():
+    if skip_insert_when_sequential_run and is_sequential_run():
         return
 
     if isinstance(prior_vars, Variable):
diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
index c820a3d882274..7db17c22b1453 100644
--- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
+++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -440,7 +440,12 @@ def op_depend_on_group(op, group):
     def _update_program(self, grad_groups):
         block = default_main_program().global_block()
 
-        remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
+        remove_op_types = [
+            'scale',
+            'c_allreduce_avg',
+            'c_allreduce_sum',
+            'c_wait_compute',
+        ]
 
         for i, group in enumerate(grad_groups[::-1]):
             # skip unfused big tensor
@@ -492,9 +497,10 @@ def _update_program(self, grad_groups):
                 )
 
             allreduce_op = block.ops[group.allreduce_op_idx]
-            assert (
-                allreduce_op.type == 'c_allreduce_sum'
-            ), f"should found c_allreduce_sum op but found {str(allreduce_op)}"
+            assert allreduce_op.type in [
+                'c_allreduce_avg',
+                'c_allreduce_sum',
+            ], f"should find c_allreduce_avg or c_allreduce_sum op but found {str(allreduce_op)}"
             allreduce_op_dist_attr = (
                 self.dist_context.get_op_dist_attr_for_program(allreduce_op)
             )
diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 617425158dd89..8d1cf45eadaf9 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -32,8 +32,8 @@
     is_backward_op,
     is_dep_skip_op,
     is_forward_op,
-    is_loss_grad_op,
     is_optimize_op,
+    naive_set_dist_op_attr_for_program_by_mesh,
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
 )
@@ -544,11 +544,17 @@ def _shard_gradient_synchronization(self, main_block):
 
         dp_ring_ids = [group.id for group in self.dp_groups]
         for idx, op in reversed(list(enumerate(main_block.ops))):
             if _is_param_grad_allreduce_op(op, main_block):
+                reduce_op_type = (
+                    "c_reduce_sum"
+                    if op.type in ["c_allreduce_sum", "c_reduce_sum"]
+                    else "c_reduce_avg"
+                )
                 input_name = op.input_arg_names[0]
                 base_name = _get_base_name_from_grad_name(input_name)
                 sharding_info = self.varname_to_sharding_info[base_name]
                 reduce_op = _insert_reduce_op(
                     main_block,
+                    reduce_op_type,
                     idx,
                     input_name,
                     sharding_info.group.id,
@@ -933,7 +939,7 @@ def _fuse_overlap_parameter_comm_stage_two(self, sharding_info):
                 sync=False,
                 op_namescope="sharding_stage2_broadcast_dep",
             )
-            if self.enable_overlap:
+            if self.enable_overlap and depend_op is not None:
                 depend_op.dist_attr.execution_stream = comm_stream
                 depend_op.dist_attr.scheduling_priority = (
                     self.comm_op_scheduling_priority
@@ -979,8 +985,9 @@ def _group_grads(
 
         first_backward_op = None
         for op in ops:
-            if is_loss_grad_op(op):
+            if is_backward_op(op):
                 first_backward_op = op
+                break
         # not backward op, sharding for inference
         if first_backward_op is None:
             return
@@ -1000,9 +1007,10 @@ def op_depend_on_group(op, group):
         while i < len(ops):
             op = ops[i]
             if is_data_parallel_reduce_op(op):
-                assert (
-                    op.type == "c_reduce_sum"
-                ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
+                assert op.type in [
+                    "c_reduce_avg",
+                    "c_reduce_sum",
+                ], "Sharding should reduce grad first and then allreduce if Hybrid Sharding with Data-Parallel"
                 grad_name = op.output_arg_names[0]
                 param_name = _get_base_name_from_grad_name(grad_name)
@@ -1035,9 +1043,10 @@ def op_depend_on_group(op, group):
                     param_name
                 ):
                     cur_group.is_in_local_shard = True
-                    assert (
-                        ops[i + 1].type == "c_allreduce_sum"
-                    ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
+                    assert ops[i + 1].type in [
+                        "c_allreduce_avg",
+                        "c_allreduce_sum",
+                    ], "Sharding should reduce grad first and then allreduce if Hybrid Sharding with Data-Parallel"
                     assert (
                         ops[i + 1].output_arg_names[0] == grad_name
                     ), "Hybrid Sharding with Data-Parallel should sync same gradient var"
@@ -1078,6 +1087,18 @@ def op_depend_on_group(op, group):
                     persistable=False,
                     stop_gradient=True,
                 )
+                ref_dist_attr = (
+                    self._dist_context.get_tensor_dist_attr_for_program(
+                        group.vars[0]
+                    )
+                )
+                set_var_dist_attr(
+                    self._dist_context,
+                    group.coalesce_var,
+                    ref_dist_attr.dims_mapping,
+                    ref_dist_attr.process_mesh,
+                    chunk_id=ref_dist_attr.chunk_id,
+                )
                 coalesce_op_map[group.coalesce_op_idx] = group
                 last_reduce_op_idx = group.reduce_op_indices.pop()
                 modify_reduce_op_map[last_reduce_op_idx] = group
@@ -1153,6 +1174,20 @@ def op_depend_on_group(op, group):
                         OP_ROLE_KEY: OpRole.Backward,
                     },
                 )
+
+                ref_dist_attr = (
+                    self._dist_context.get_tensor_dist_attr_for_program(
+                        group.coalesce_var
+                    )
+                )
+                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                    coalesce_op,
+                    ref_dist_attr.process_mesh,
+                    ref_dist_attr.dims_mapping,
+                    self._dist_context,
+                    chunk_id=ref_dist_attr.chunk_id,
+                )
+
                 depend_op = insert_dependencies_for_vars(
                     block,
                     idx,
@@ -1219,7 +1254,7 @@ def _overlap_grad_comm(
         grad_comm_op_to_stream_idx = {}
         for idx, op in enumerate(ops):
             if is_data_parallel_reduce_op(op):
-                if op.type == "c_allreduce_sum":
+                if op.type in ["c_allreduce_avg", "c_allreduce_sum"]:
                     continue
                 stream_idx = reduce_op_count % self.grad_comm_stream_num
                 grad_comm_op_to_stream_idx[op] = stream_idx
@@ -1245,6 +1280,8 @@ def _overlap_grad_comm(
                         grad_group.vars[-1],
                         grad_group.coalesce_var,
                         comm_stream,
+                        "sharding_grad_comm_dep",
+                        op.dist_attr,
                     )
                 ]
                 # post dep
@@ -1257,6 +1294,8 @@ def _overlap_grad_comm(
                        grad_group.coalesce_var,
                        grad_group.vars,
                        comm_stream,
+                        "sharding_grad_comm_dep",
+                        op.dist_attr,
                     )
                 )
 
@@ -1265,11 +1304,13 @@ def _overlap_grad_comm(
                 op.dist_attr.scheduling_priority = (
                     self.comm_op_scheduling_priority
                 )
-                op._set_attr("ring_id", comm_group.id)
 
                 if self.sharding_hybrid_dp and grad_group.is_in_local_shard:
                     next_op = ops[idx + 1]
-                    assert next_op.type == "c_allreduce_sum"
+                    assert next_op.type in [
+                        "c_allreduce_avg",
+                        "c_allreduce_sum",
+                    ]
                     assert next_op.output("Out")[0] == reduce_varname
                     # FIXME hybrid sharding-dp support multi comm & stream in feature
                     # next_op._set_attr("ring_id", comm_group.id)
@@ -1279,6 +1320,34 @@ def _overlap_grad_comm(
                     )
                     idx += 1
 
+                # NOTE(Ruibiao): Why add dependency here?
+                # It is a hack to delay GC for coalesce_var, which significantly reduces memory usage.
+                # With the pattern of reduce_sum + scale, the coalesce_var is used by the reduce_sum
+                # op on the comm-stream, and then released by the scale op on the comp-stream. Since
+                # the generated and released op are both in comp-stream, the allocation of the
+                # coalesce_var can be fast-GC and reused by subsequent comp-op. However, in the reduce_avg
+                # pattern, the coalesce_var is released by the reduce_avg op in the comm-stream,
+                # triggering a cross-stream GC. In such case, an event is recorded on the underlying
+                # allocation, and the memory cannot be reused by other comp-ops, resulting in an
+                # increase in memory usage. For more details, see the code of StreamSafeCUDAAllocator.
+                # This issue should be fixed using CUDAMallocAsyncAllocator in the future.
+                if (
+                    op.type == "c_reduce_avg"
+                    and not grad_group.is_in_local_shard
+                ):
+                    if idx not in dep_map:
+                        dep_map[idx] = []
+                    dep_map[idx].append(
+                        (
+                            idx + 1,
+                            grad_group.coalesce_var,
+                            grad_group.coalesce_var,
+                            None,
+                            "sharding_reduce_avg_dep",
+                            op.dist_attr,
+                        )
+                    )
+
                 reduce_op_count += 1
 
             idx += 1
@@ -1286,7 +1355,18 @@ def _overlap_grad_comm(
         # insert deps
         indice = sorted(dep_map.keys(), reverse=True)
         for i in indice:
-            for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]:
+            for (
+                idx,
+                prior_vars,
+                post_vars,
+                comm_stream,
+                op_namescope,
+                dist_attr,
+            ) in dep_map[i][::-1]:
+                skip_insert_when_sequential_run = (
+                    False if op_namescope == "sharding_reduce_avg_dep" else True
+                )
+
                 depend_op = insert_dependencies_for_vars(
                     block,
                     idx,
@@ -1299,13 +1379,23 @@ def _overlap_grad_comm(
                     ],  # hack to avoid initialize the dist attr for coalesce var
                     is_recompute=False,
                     sync=False,
-                    op_namescope="sharding_grad_comm_dep",
-                )
-                depend_op.dist_attr.execution_stream = comm_stream
-                depend_op.dist_attr.scheduling_priority = (
-                    self.comm_op_scheduling_priority
+                    op_namescope=op_namescope,
+                    skip_insert_when_sequential_run=skip_insert_when_sequential_run,
                 )
 
+                if depend_op is not None:
+                    naive_set_dist_op_attr_for_program_by_mesh(
+                        depend_op,
+                        process_mesh=dist_attr.process_mesh,
+                        ctx=self._dist_context,
+                        chunk_id=dist_attr.chunk_id,
+                    )
+                    if comm_stream is not None:
+                        depend_op.dist_attr.execution_stream = comm_stream
+                        depend_op.dist_attr.scheduling_priority = (
+                            self.comm_op_scheduling_priority
+                        )
+
         # hierarchical grad comm
         if self.enable_hierarchical_comm:
             # NOTE so far we only support Isomorphic cluster with 8 ranks per node
@@ -1467,6 +1557,7 @@ def _insert_init_and_broadcast_op(
 
 def _insert_reduce_op(
     block,
+    op_type,
     insert_idx,
     reduce_var,
     ring_id,
@@ -1480,7 +1571,7 @@ def _insert_reduce_op(
     ), f"root id should be a positive int, but now root id is {root_id}"
     new_op = block._insert_op_without_sync(
         insert_idx,
-        type='c_reduce_sum',
+        type=op_type,
         inputs={'X': [reduce_var]},
         outputs={'Out': [reduce_var]},
         attrs={
diff --git a/python/setup.py.in b/python/setup.py.in
index 3ba1dc05e4976..98246fdbf4dc5 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -54,6 +54,11 @@ def get_major():
 def get_minor():
     return int(_get_version_detail(1))
 
+def get_nccl_version():
+    if '@WITH_NCCL@' == 'ON':
+        return @NCCL_VERSION@
+    return 0
+
 def get_patch():
     return str(_get_version_detail(2))
 
@@ -119,6 +124,7 @@ full_version = '%(major)d.%(minor)d.%(patch)s'
 major = '%(major)d'
 minor = '%(minor)d'
 patch = '%(patch)s'
+nccl_version = '%(nccl)d'
 rc = '%(rc)d'
 cuda_version = '%(cuda)s'
 cudnn_version = '%(cudnn)s'
@@ -130,7 +136,7 @@ commit = '%(commit)s'
 with_mkl = '%(with_mkl)s'
 cinn_version = '%(cinn)s'
 
-__all__ = ['cuda', 'cudnn', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc']
+__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc']
 
 def show():
     """Get the version of paddle if `paddle` package if tagged. Otherwise, output the corresponding commit id.
@@ -205,6 +211,7 @@ def show():
     print('commit:', commit)
     print('cuda:', cuda_version)
     print('cudnn:', cudnn_version)
+    print('nccl:', nccl_version)
     print('xpu:', xpu_version)
     print('xpu_xccl:', xpu_xccl_version)
     print('xpu_xhpc:', xpu_xhpc_version)
@@ -213,6 +220,9 @@ def show():
 def mkl():
     return with_mkl
 
+def nccl():
+    return nccl_version
+
 def cuda():
     """Get cuda version of paddle package.
 
@@ -336,6 +346,7 @@ def cinn():
         'major': get_major(),
         'minor': get_minor(),
         'patch': get_patch(),
+        'nccl': get_nccl_version(),
         'rc': RC,
         'version': '${PADDLE_VERSION}',
         'cuda': get_cuda_version(),
diff --git a/setup.py b/setup.py
index 2601cfe7b11b3..fd94bfa11accd 100644
--- a/setup.py
+++ b/setup.py
@@ -344,6 +344,12 @@ def get_patch():
     return str(_get_version_detail(2))
 
 
+def get_nccl_version():
+    if env_dict.get("WITH_NCCL") == 'ON':
+        return int(env_dict.get("NCCL_VERSION"))
+    return 0
+
+
 def get_cuda_version():
     with_gpu = env_dict.get("WITH_GPU")
     if with_gpu == 'ON':
@@ -441,6 +447,7 @@ def write_version_py(filename='paddle/version/__init__.py'):
 major = '%(major)d'
 minor = '%(minor)d'
 patch = '%(patch)s'
+nccl_version = '%(nccl)d'
 rc = '%(rc)d'
 cuda_version = '%(cuda)s'
 cudnn_version = '%(cudnn)s'
@@ -452,7 +459,7 @@ def write_version_py(filename='paddle/version/__init__.py'):
 with_mkl = '%(with_mkl)s'
 cinn_version = '%(cinn)s'
 
-__all__ = ['cuda', 'cudnn', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc']
+__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc']
 
 def show():
     """Get the version of paddle if `paddle` package if tagged. Otherwise, output the corresponding commit id.
@@ -526,6 +533,7 @@ def show():
     print('commit:', commit)
     print('cuda:', cuda_version)
     print('cudnn:', cudnn_version)
+    print('nccl:', nccl_version)
     print('xpu:', xpu_version)
     print('xpu_xccl:', xpu_xccl_version)
     print('xpu_xhpc:', xpu_xhpc_version)
@@ -534,6 +542,9 @@ def show():
 def mkl():
     return with_mkl
 
+def nccl():
+    return nccl_version
+
 def cuda():
     """Get cuda version of paddle package.
 
@@ -659,6 +670,7 @@ def cinn():
         'major': get_major(),
         'minor': get_minor(),
         'patch': get_patch(),
+        'nccl': get_nccl_version(),
         'rc': RC,
         'version': env_dict.get("PADDLE_VERSION"),
         'cuda': get_cuda_version(),
diff --git a/test/auto_parallel/sharding_pass_unittest.py b/test/auto_parallel/sharding_pass_unittest.py
index 82d17e821b7db..762fb6e239582 100644
--- a/test/auto_parallel/sharding_pass_unittest.py
+++ b/test/auto_parallel/sharding_pass_unittest.py
@@ -24,9 +24,10 @@
 paddle.enable_static()
 
 
-def apply_pass(use_sharding=False, stage=None):
+def apply_pass(use_sharding=False, stage=None, use_allreduce_avg=False):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
+    strategy.gradient_scale_using_allreduce_avg = use_allreduce_avg
     # strategy.reinit = True
     if use_sharding:
         sharding = strategy.sharding
@@ -67,10 +68,12 @@ def init(self, engine):
         np.random.seed(2022)
         random.seed(2022)
 
-    def get_engine(self, use_sharding=False, stage=None):
+    def get_engine(
+        self, use_sharding=False, stage=None, use_allreduce_avg=False
+    ):
         reset_prog()
 
-        strategy = apply_pass(use_sharding, stage)
+        strategy = apply_pass(use_sharding, stage, use_allreduce_avg)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         # NOTE: setting opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) will cause precision problem
         opt = paddle.optimizer.AdamW(learning_rate=0.00001)
@@ -150,6 +153,32 @@ def test_sharding_pass(self):
         sharding3_losses = np.array(history.history["loss"])
         self.check_results(dp_losses, sharding3_losses)
 
+        # dp2 training using allreduce avg
+        dp_engine_using_allreduce_avg = self.get_engine(use_allreduce_avg=True)
+        dp_engine_using_allreduce_avg.prepare(
+            inputs_spec=input_spec, labels_spec=label_spec, mode='train'
+        )
+        dp_engine_using_allreduce_avg.save(
+            "./dp_engine_using_allreduce_avg", training=True
+        )
+        history = dp_engine_using_allreduce_avg.fit(
+            self.dataset, 3, batch_size=self.batch_size
+        )
+        dp_losses_using_allreduce_avg = np.array(history.history["loss"])
+
+        # sharding2 stage2 training using allreduce avg
+        sharding2_engine_using_allreduce_avg = self.get_engine(True, 2, True)
+        sharding2_engine_using_allreduce_avg.load(
+            "./dp_engine_using_allreduce_avg"
+        )
+        history = sharding2_engine_using_allreduce_avg.fit(
+            self.dataset, 3, batch_size=self.batch_size
+        )
+        sharding2_losses_using_allreduce_avg = np.array(history.history["loss"])
+        self.check_results(
+            dp_losses_using_allreduce_avg, sharding2_losses_using_allreduce_avg
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/auto_parallel/test_dist_embedding.py b/test/auto_parallel/test_dist_embedding.py
index f8dbd0fc9494d..7304b06aeb274 100644
--- a/test/auto_parallel/test_dist_embedding.py
+++ b/test/auto_parallel/test_dist_embedding.py
@@ -90,7 +90,7 @@ def test_lookup_table_v1_mp_dp(self):
             'c_embedding_grad',
             'c_allreduce_sum',
             'scale',
-        ]
+        ], f"Unexpected op types: {op_types}"
 
 if __name__ == "__main__":
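For reference, below is a minimal usage sketch (not part of the patch) of the `gradient_scale_using_allreduce_avg` strategy option that this diff wires from `constants.py` through `engine.py` into the data-parallel passes, mirroring `sharding_pass_unittest.py` above. The `auto` import path and the `Engine` wiring are assumptions based on the existing auto-parallel test suite; `paddle.version.nccl()` is the helper added in `setup.py`/`setup.py.in` here, and `c_allreduce_avg` only works when the linked NCCL provides `ncclAvg` (NCCL >= 2.10).

```python
import paddle
from paddle.distributed.fleet import auto  # assumed import path, as used by the auto-parallel tests

strategy = auto.Strategy()
strategy.auto_mode = "semi"
# paddle.version.nccl() returns the NCCL version number added by this patch
# (e.g. '21602' for NCCL 2.16.2, or 0 when Paddle was built without NCCL).
# Only request avg-based gradient synchronization when ncclAvg is available,
# matching the check in sync_and_scale_gradients().
strategy.gradient_scale_using_allreduce_avg = int(paddle.version.nccl()) > 21000

# Hypothetical model/loss/optimizer objects; with them defined, the engine would be built as:
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(train_dataset, epochs=3, batch_size=8)
```

With the flag enabled, the pass replaces the `c_allreduce_sum` + `scale` pair with a single `c_allreduce_avg` (or `c_reduce_avg` under sharding), which is the behavior exercised by the new test case in `sharding_pass_unittest.py`.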