From e48ff62e49fcd76ce5f95396c8ca19c2a11d5990 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 5 Feb 2024 16:50:17 +0800 Subject: [PATCH 01/18] Using allreduce_avg to eliminate scale in auto parallel DP --- .../collective/c_allreduce_avg_op.cc | 45 +++++++++++++++++++ .../collective/c_allreduce_avg_op.cu.cc | 35 +++++++++++++++ .../operators/collective/c_allreduce_op.h | 8 +++- .../operators/collective/c_reduce_avg_op.cc | 44 ++++++++++++++++++ .../collective/c_reduce_avg_op.cu.cc | 35 +++++++++++++++ .../fluid/operators/collective/c_reduce_op.h | 8 +++- python/env_dict.py.in | 1 + .../auto_parallel/static/operators/common.py | 21 +++++++-- .../passes/auto_parallel_sharding.py | 11 ++++- python/setup.py.in | 13 +++++- setup.py | 8 +++- 11 files changed, 220 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_allreduce_avg_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_avg_op.cc create mode 100644 paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc new file mode 100644 index 0000000000000..3343406a02b6c --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace framework { +class OpDesc; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +} // namespace paddle + +namespace paddle { +namespace operators { + +class CAllReduceAvgOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Avg"; } +}; + +DECLARE_INPLACE_OP_INFERER(AllreduceAvgInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_avg, + ops::CAllReduceOp, + ops::CAllReduceAvgOpMaker, + ops::AllreduceAvgInplaceInferer) diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc new file mode 100644 index 0000000000000..d3f0b45f64432 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { +DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceAvg, kRedAvg) +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +PD_REGISTER_STRUCT_KERNEL(c_allreduce_avg, + GPU, + ALL_LAYOUT, + ops::CAllReduceAvgCUDAKernel, + float, + double, + int, + int64_t, + plat::float16, + plat::bfloat16) {} diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 9cd472f421788..e0386038b4fc7 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -48,7 +48,7 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { -enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg }; class CAllReduceOp : public framework::OperatorWithKernel { public: @@ -413,6 +413,12 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { nccl_red_type = ncclProd; break; +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + case kRedAvg: + nccl_red_type = ncclAvg; + break; +#endif + default: PADDLE_THROW(platform::errors::InvalidArgument( "Invalid reduce type: %d", red_type)); diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cc new file mode 100644 index 0000000000000..53ce6e221a9f8 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace framework { +class OpDesc; +template +class EmptyGradOpMaker; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +} // namespace paddle + +namespace paddle { +namespace operators { + +class CReduceAvgOpMaker : public CReduceOpMaker { + protected: + std::string GetName() const override { return "Avg"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_reduce_avg, + ops::CReduceOp, + ops::CReduceAvgOpMaker); diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc new file mode 100644 index 0000000000000..07d2cc748900e --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace operators { +DEFINE_C_REDUCE_CUDA_KERNEL(CReduceAvg, kRedAvg); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +PD_REGISTER_STRUCT_KERNEL(c_reduce_avg, + GPU, + ALL_LAYOUT, + ops::CReduceAvgCUDAKernel, + float, + double, + int, + int64_t, + plat::float16, + plat::bfloat16) {} diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index 20884d1ae8a96..c760953692edb 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -50,7 +50,7 @@ PHI_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { -enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg }; class CReduceOp : public framework::OperatorWithKernel { public: @@ -304,6 +304,12 @@ class CReduceOpCUDAKernel : public framework::OpKernel { nccl_red_type = ncclProd; break; +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + case kRedAvg: + nccl_red_type = ncclAvg; + break; +#endif + default: PADDLE_ENFORCE_EQ(true, false, diff --git a/python/env_dict.py.in b/python/env_dict.py.in index ca15773766180..bd9617b9b7c82 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -1,4 +1,5 @@ env_dict={ + 'NCCL_VERSION':'@NCCL_VERSION@', 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', 'PADDLE_VERSION':'@PADDLE_VERSION@', 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 75a45a510b0ca..5bf5c7738779e 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -503,6 +503,15 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): dist_op_context = dist_ctx.dist_op_context main_block = dist_op_context.work_block + allreduce_type = "c_allreduce_sum" + need_scale = dist_ctx.gradient_scale + + # With nccl_version > 2.10.00, we can use c_allreduce_avg to replace c_allreduce_sum and eliminate the scale op. 
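+    # A brief sketch of the equivalence this relies on: for a data-parallel
+    # group of dp_degree ranks, c_allreduce_avg(grad) computes the same result
+    # as c_allreduce_sum(grad) followed by scale(grad, 1.0 / dp_degree), so the
+    # separate scale op appended below becomes unnecessary.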
+ print(f"nccl_version: {paddle.version.nccl()}") + if need_scale and paddle.version.nccl() > 21000: + allreduce_type = "c_allreduce_avg" + need_scale = False + for group in groups: group_size = len(group.ranks) @@ -510,7 +519,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): added_ops = [] grad_var = main_block.var(var_name) allreduce_op = main_block.append_op( - type='c_allreduce_sum', + type=allreduce_type, inputs={'X': [grad_var]}, outputs={'Out': [grad_var]}, attrs={ @@ -524,7 +533,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): ) added_ops.append(allreduce_op) - if dist_ctx.gradient_scale: + if need_scale: scale_op = main_block.append_op( type='scale', inputs={'X': grad_var}, @@ -654,7 +663,13 @@ def is_data_parallel_scale_op(op): def is_data_parallel_reduce_op(op): return ( - op.type in ["c_reduce_sum", "c_allreduce_sum"] + op.type + in [ + "c_allreduce_sum", + "c_allreduce_avg", + "c_reduce_sum", + "c_reduce_avg", + ] and op.desc.has_attr("op_namescope") and ParallelMode.DataParallel in op.desc.attr("op_namescope") ) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 2f983dc2d06e9..60c2988164384 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -32,7 +32,6 @@ is_backward_op, is_dep_skip_op, is_forward_op, - is_loss_grad_op, is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, @@ -544,11 +543,17 @@ def _shard_gradient_synchronization(self, main_block): dp_ring_ids = [group.id for group in self.dp_groups] for idx, op in reversed(list(enumerate(main_block.ops))): if _is_param_grad_allreduce_op(op, main_block): + reduce_op_type = ( + "c_reduce_sum" + if op.type in ["c_allreduce_sum", "c_reduce_sum"] + else "c_reduce_avg" + ) input_name = op.input_arg_names[0] base_name = _get_base_name_from_grad_name(input_name) sharding_info = self.varname_to_sharding_info[base_name] reduce_op = _insert_reduce_op( main_block, + reduce_op_type, idx, input_name, sharding_info.group.id, @@ -979,8 +984,9 @@ def _group_grads( first_backward_op = None for op in ops: - if is_loss_grad_op(op): + if is_backward_op(op): first_backward_op = op + break # not backward op, sharding for inference if first_backward_op is None: return @@ -1467,6 +1473,7 @@ def _insert_init_and_broadcast_op( def _insert_reduce_op( block, + op_type, insert_idx, reduce_var, ring_id, diff --git a/python/setup.py.in b/python/setup.py.in index 37cbb638e4aab..70b8bacc62188 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -54,6 +54,11 @@ def get_major(): def get_minor(): return int(_get_version_detail(1)) +def get_nccl_version(): + if '@WITH_NCCL@' == 'ON': + return '@NCCL_VERSION@' + return 0 + def get_patch(): return str(_get_version_detail(2)) @@ -119,6 +124,7 @@ full_version = '%(major)d.%(minor)d.%(patch)s' major = '%(major)d' minor = '%(minor)d' patch = '%(patch)s' +nccl_version = '%(nccl)d' rc = '%(rc)d' cuda_version = '%(cuda)s' cudnn_version = '%(cudnn)s' @@ -130,7 +136,7 @@ commit = '%(commit)s' with_mkl = '%(with_mkl)s' cinn_version = '%(cinn)s' -__all__ = ['cuda', 'cudnn', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc'] +__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc'] def show(): """Get the version of paddle if `paddle` package if tagged. Otherwise, output the corresponding commit id. 
@@ -205,6 +211,7 @@ def show(): print('commit:', commit) print('cuda:', cuda_version) print('cudnn:', cudnn_version) + print('nccl:', nccl_version) print('xpu:', xpu_version) print('xpu_xccl:', xpu_xccl_version) print('xpu_xhpc:', xpu_xhpc_version) @@ -213,6 +220,9 @@ def show(): def mkl(): return with_mkl +def nccl(): + return nccl_version + def cuda(): """Get cuda version of paddle package. @@ -336,6 +346,7 @@ def cinn(): 'major': get_major(), 'minor': get_minor(), 'patch': get_patch(), + 'nccl': get_nccl_version(), 'rc': RC, 'version': '${PADDLE_VERSION}', 'cuda': get_cuda_version(), diff --git a/setup.py b/setup.py index 67774e58138b8..fb5a4e8dc17ac 100644 --- a/setup.py +++ b/setup.py @@ -441,6 +441,7 @@ def write_version_py(filename='paddle/version/__init__.py'): major = '%(major)d' minor = '%(minor)d' patch = '%(patch)s' +nccl_version = '%(nccl)d' rc = '%(rc)d' cuda_version = '%(cuda)s' cudnn_version = '%(cudnn)s' @@ -452,7 +453,7 @@ def write_version_py(filename='paddle/version/__init__.py'): with_mkl = '%(with_mkl)s' cinn_version = '%(cinn)s' -__all__ = ['cuda', 'cudnn', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc'] +__all__ = ['cuda', 'cudnn', 'nccl', 'show', 'xpu', 'xpu_xccl', 'xpu_xhpc'] def show(): """Get the version of paddle if `paddle` package if tagged. Otherwise, output the corresponding commit id. @@ -526,6 +527,7 @@ def show(): print('commit:', commit) print('cuda:', cuda_version) print('cudnn:', cudnn_version) + print('nccl:', nccl_version) print('xpu:', xpu_version) print('xpu_xccl:', xpu_xccl_version) print('xpu_xhpc:', xpu_xhpc_version) @@ -534,6 +536,9 @@ def show(): def mkl(): return with_mkl +def nccl(): + return nccl_version + def cuda(): """Get cuda version of paddle package. @@ -659,6 +664,7 @@ def cinn(): 'major': get_major(), 'minor': get_minor(), 'patch': get_patch(), + 'nccl': env_dict.get("NCCL_VERSION"), 'rc': RC, 'version': env_dict.get("PADDLE_VERSION"), 'cuda': get_cuda_version(), From 7f4d5c8f7c8c6946f1c1cc359800dd86e7a99805 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 5 Feb 2024 19:37:18 +0800 Subject: [PATCH 02/18] Fix nccl_version api --- .../distributed/auto_parallel/static/operators/common.py | 3 +-- python/setup.py.in | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 5bf5c7738779e..50247ba92812d 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -507,8 +507,7 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): need_scale = dist_ctx.gradient_scale # With nccl_version > 2.10.00, we can use c_allreduce_avg to replace c_allreduce_sum and eliminate the scale op. 
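+    # Here 21000 corresponds to NCCL 2.10.00, assuming the reported version code
+    # follows NCCL's major * 10000 + minor * 100 + patch convention.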
- print(f"nccl_version: {paddle.version.nccl()}") - if need_scale and paddle.version.nccl() > 21000: + if need_scale and int(paddle.version.nccl()) > 21000: allreduce_type = "c_allreduce_avg" need_scale = False diff --git a/python/setup.py.in b/python/setup.py.in index 70b8bacc62188..19e9d9ed9288f 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -56,7 +56,7 @@ def get_minor(): def get_nccl_version(): if '@WITH_NCCL@' == 'ON': - return '@NCCL_VERSION@' + return @NCCL_VERSION@ return 0 def get_patch(): From 81d2064c1b17392747cd9079a1f2ec2e68b55100 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 6 Feb 2024 11:04:42 +0800 Subject: [PATCH 03/18] Fix nccl_version api --- python/env_dict.py.in | 2 +- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/env_dict.py.in b/python/env_dict.py.in index bd9617b9b7c82..ee2ddf94c63c6 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -1,5 +1,5 @@ env_dict={ - 'NCCL_VERSION':'@NCCL_VERSION@', + 'NCCL_VERSION':@NCCL_VERSION@, 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', 'PADDLE_VERSION':'@PADDLE_VERSION@', 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 60c2988164384..700a726b736a8 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -938,7 +938,7 @@ def _fuse_overlap_parameter_comm_stage_two(self, sharding_info): sync=False, op_namescope="sharding_stage2_broadcast_dep", ) - if self.enable_overlap: + if self.enable_overlap and depend_op is not None: depend_op.dist_attr.execution_stream = comm_stream depend_op.dist_attr.scheduling_priority = ( self.comm_op_scheduling_priority From 659026959c2baf0aad03047e7ce20367b54287e9 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 6 Feb 2024 12:35:15 +0800 Subject: [PATCH 04/18] Fix nccl_version api --- python/env_dict.py.in | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/env_dict.py.in b/python/env_dict.py.in index ee2ddf94c63c6..bd9617b9b7c82 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -1,5 +1,5 @@ env_dict={ - 'NCCL_VERSION':@NCCL_VERSION@, + 'NCCL_VERSION':'@NCCL_VERSION@', 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', 'PADDLE_VERSION':'@PADDLE_VERSION@', 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', diff --git a/setup.py b/setup.py index fb5a4e8dc17ac..cdd4cb2fbf1c6 100644 --- a/setup.py +++ b/setup.py @@ -664,7 +664,7 @@ def cinn(): 'major': get_major(), 'minor': get_minor(), 'patch': get_patch(), - 'nccl': env_dict.get("NCCL_VERSION"), + 'nccl': int(env_dict.get("NCCL_VERSION")), 'rc': RC, 'version': env_dict.get("PADDLE_VERSION"), 'cuda': get_cuda_version(), From 92c893231314d98ab1db855c7187369078cffa66 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 6 Feb 2024 14:14:19 +0800 Subject: [PATCH 05/18] Update code --- .../distributed/passes/auto_parallel_sharding.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 700a726b736a8..532b50a387334 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1006,9 +1006,10 @@ def op_depend_on_group(op, group): while i < len(ops): op = 
ops[i] if is_data_parallel_reduce_op(op): - assert ( - op.type == "c_reduce_sum" - ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" + assert op.type in [ + "c_reduce_avg", + "c_reduce_sum", + ], "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" grad_name = op.output_arg_names[0] param_name = _get_base_name_from_grad_name(grad_name) @@ -1041,9 +1042,10 @@ def op_depend_on_group(op, group): param_name ): cur_group.is_in_local_shard = True - assert ( - ops[i + 1].type == "c_allreduce_sum" - ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" + assert ops[i + 1].type in [ + "c_allreduce_avg", + "c_allreduce_sum", + ], "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel" assert ( ops[i + 1].output_arg_names[0] == grad_name ), "Hybrid Sharding with Data-Parallel should sync same gradient var" @@ -1487,7 +1489,7 @@ def _insert_reduce_op( ), f"root id should be a positive int, but now root id is {root_id}" new_op = block._insert_op_without_sync( insert_idx, - type='c_reduce_sum', + type=op_type, inputs={'X': [reduce_var]}, outputs={'Out': [reduce_var]}, attrs={ From 3c11a689fb49232d83fac8edbde97f0bb185d2ce Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 6 Feb 2024 19:26:08 +0800 Subject: [PATCH 06/18] Update code --- python/env_dict.py.in | 1 + setup.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/env_dict.py.in b/python/env_dict.py.in index bd9617b9b7c82..92e6b6bf3d64c 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -5,6 +5,7 @@ env_dict={ 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', 'WITH_GPU':'@WITH_GPU@', + 'WITH_NCCL':'@WITH_NCCL' 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', diff --git a/setup.py b/setup.py index cdd4cb2fbf1c6..869bf61f76ffd 100644 --- a/setup.py +++ b/setup.py @@ -344,6 +344,12 @@ def get_patch(): return str(_get_version_detail(2)) +def get_nccl_version(): + if env_dict.get("WITH_NCCL") == 'ON': + return int(env_dict.get("NCCL_VERSION")) + return 0 + + def get_cuda_version(): with_gpu = env_dict.get("WITH_GPU") if with_gpu == 'ON': @@ -664,7 +670,7 @@ def cinn(): 'major': get_major(), 'minor': get_minor(), 'patch': get_patch(), - 'nccl': int(env_dict.get("NCCL_VERSION")), + 'nccl': get_nccl_version(), 'rc': RC, 'version': env_dict.get("PADDLE_VERSION"), 'cuda': get_cuda_version(), From 29e69bcdc3d75930d26c241d58b613c8fe15e280 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 6 Feb 2024 19:28:30 +0800 Subject: [PATCH 07/18] Fix typos --- python/env_dict.py.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/env_dict.py.in b/python/env_dict.py.in index 92e6b6bf3d64c..b13582484e96e 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -5,7 +5,7 @@ env_dict={ 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', 'WITH_GPU':'@WITH_GPU@', - 'WITH_NCCL':'@WITH_NCCL' + 'WITH_NCCL':'@WITH_NCCL@' 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', From f7f2e27ba1b878928bd340078b9d4031ae614610 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 19 Feb 2024 15:18:52 +0800 Subject: [PATCH 08/18] Update code --- python/env_dict.py.in | 2 +- 
python/paddle/distributed/auto_parallel/static/dist_op.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/env_dict.py.in b/python/env_dict.py.in index b13582484e96e..fd8105402f4ee 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -5,7 +5,7 @@ env_dict={ 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', 'WITH_GPU':'@WITH_GPU@', - 'WITH_NCCL':'@WITH_NCCL@' + 'WITH_NCCL':'@WITH_NCCL@', 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', diff --git a/python/paddle/distributed/auto_parallel/static/dist_op.py b/python/paddle/distributed/auto_parallel/static/dist_op.py index b27e27ee98330..8d28c43eef4d7 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_op.py +++ b/python/paddle/distributed/auto_parallel/static/dist_op.py @@ -130,6 +130,8 @@ def __str__(self): f", process_mesh ({annotated_str}): {self.dist_attr.process_mesh}" ) + str += f" , execution_stream: {self.dist_attr.execution_stream}" + for arg_name in self.serial_op.desc.input_arg_names(): try: dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) From a5b525e3bb52c96a29844aef0d132cee9f07dcfe Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 20 Feb 2024 14:15:58 +0800 Subject: [PATCH 09/18] Add dependency for reduce_avg in sharding --- .../distributed/auto_parallel/static/utils.py | 3 +- .../passes/auto_parallel_sharding.py | 48 +++++++++++++++---- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 2fe48dcdfa5f6..119bcc67ffa9a 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -2124,12 +2124,13 @@ def insert_dependencies_for_two_ops( is_recompute=False, sync=False, op_namescope=None, + skip_insert_when_sequential_run=True, ): """ dependency: prior_op should be run before posterior_op """ - if is_sequential_run(): + if skip_insert_when_sequential_run and is_sequential_run(): return assert ( diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 532b50a387334..9a7caf26b6f9c 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1227,7 +1227,7 @@ def _overlap_grad_comm( grad_comm_op_to_stream_idx = {} for idx, op in enumerate(ops): if is_data_parallel_reduce_op(op): - if op.type == "c_allreduce_sum": + if op.type in ["c_allreduce_avg", "c_allreduce_sum"]: continue stream_idx = reduce_op_count % self.grad_comm_stream_num grad_comm_op_to_stream_idx[op] = stream_idx @@ -1253,6 +1253,7 @@ def _overlap_grad_comm( grad_group.vars[-1], grad_group.coalesce_var, comm_stream, + "sharding_grad_comm_dep", ) ] # post dep @@ -1265,6 +1266,7 @@ def _overlap_grad_comm( grad_group.coalesce_var, grad_group.vars, comm_stream, + "sharding_grad_comm_dep", ) ) @@ -1273,11 +1275,13 @@ def _overlap_grad_comm( op.dist_attr.scheduling_priority = ( self.comm_op_scheduling_priority ) - op._set_attr("ring_id", comm_group.id) if self.sharding_hybrid_dp and grad_group.is_in_local_shard: next_op = ops[idx + 1] - assert next_op.type == "c_allreduce_sum" + assert next_op.type in [ + "c_allreduce_avg", + "c_allreduce_sum", + ] assert next_op.output("Out")[0] == reduce_varname # FIXME hybrid 
sharding-dp support multi comm & stream in feature # next_op._set_attr("ring_id", comm_group.id) @@ -1287,6 +1291,20 @@ def _overlap_grad_comm( ) idx += 1 + if ( + op.type == "c_reduce_avg" + and not grad_group.is_in_local_shard + ): + dep_map[idx].append( + ( + idx, + reduce_varname, + reduce_varname, + None, + "sharding_reduce_avg_dep", + ) + ) + reduce_op_count += 1 idx += 1 @@ -1294,7 +1312,17 @@ def _overlap_grad_comm( # insert deps indice = sorted(dep_map.keys(), reverse=True) for i in indice: - for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]: + for ( + idx, + prior_vars, + post_vars, + comm_stream, + op_namescope, + ) in dep_map[i][::-1]: + skip_insert_when_sequential_run = ( + False if op_namescope == "sharding_reduce_avg_dep" else True + ) + depend_op = insert_dependencies_for_vars( block, idx, @@ -1307,12 +1335,14 @@ def _overlap_grad_comm( ], # hack to avoid initialize the dist attr for coalesce var is_recompute=False, sync=False, - op_namescope="sharding_grad_comm_dep", - ) - depend_op.dist_attr.execution_stream = comm_stream - depend_op.dist_attr.scheduling_priority = ( - self.comm_op_scheduling_priority + op_namescope=op_namescope, + skip_insert_when_sequential_run=skip_insert_when_sequential_run, ) + if depend_op is not None and comm_stream is not None: + depend_op.dist_attr.execution_stream = comm_stream + depend_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) # hierarchical grad comm if self.enable_hierarchical_comm: From 8653e0c0f323a53a7f4137e949edd8d803995c75 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 21 Feb 2024 14:55:35 +0800 Subject: [PATCH 10/18] Update code --- .../distributed/auto_parallel/static/utils.py | 6 ++-- .../passes/auto_parallel_sharding.py | 29 ++++++++++++++----- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 119bcc67ffa9a..b30da39e68657 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -2124,13 +2124,12 @@ def insert_dependencies_for_two_ops( is_recompute=False, sync=False, op_namescope=None, - skip_insert_when_sequential_run=True, ): """ dependency: prior_op should be run before posterior_op """ - if skip_insert_when_sequential_run and is_sequential_run(): + if is_sequential_run(): return assert ( @@ -2194,12 +2193,13 @@ def insert_dependencies_for_vars( sync=False, op_namescope=None, use_nop=False, + skip_insert_when_sequential_run=True, ): """ dependency: op that generates prior_vars should be run before op that generates post_vars """ - if is_sequential_run(): + if skip_insert_when_sequential_run and is_sequential_run(): return if isinstance(prior_vars, Variable): diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 9a7caf26b6f9c..6101e744334c6 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -33,6 +33,7 @@ is_dep_skip_op, is_forward_op, is_optimize_op, + naive_set_dist_op_attr_for_program_by_mesh, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) @@ -1254,6 +1255,7 @@ def _overlap_grad_comm( grad_group.coalesce_var, comm_stream, "sharding_grad_comm_dep", + op.dist_attr, ) ] # post dep @@ -1267,6 +1269,7 @@ def _overlap_grad_comm( grad_group.vars, comm_stream, 
"sharding_grad_comm_dep", + op.dist_attr, ) ) @@ -1295,13 +1298,16 @@ def _overlap_grad_comm( op.type == "c_reduce_avg" and not grad_group.is_in_local_shard ): + if idx not in dep_map: + dep_map[idx] = [] dep_map[idx].append( ( - idx, - reduce_varname, - reduce_varname, + idx + 1, + grad_group.coalesce_var, + grad_group.coalesce_var, None, "sharding_reduce_avg_dep", + op.dist_attr, ) ) @@ -1318,6 +1324,7 @@ def _overlap_grad_comm( post_vars, comm_stream, op_namescope, + dist_attr, ) in dep_map[i][::-1]: skip_insert_when_sequential_run = ( False if op_namescope == "sharding_reduce_avg_dep" else True @@ -1338,11 +1345,19 @@ def _overlap_grad_comm( op_namescope=op_namescope, skip_insert_when_sequential_run=skip_insert_when_sequential_run, ) - if depend_op is not None and comm_stream is not None: - depend_op.dist_attr.execution_stream = comm_stream - depend_op.dist_attr.scheduling_priority = ( - self.comm_op_scheduling_priority + + if depend_op is not None: + naive_set_dist_op_attr_for_program_by_mesh( + depend_op, + process_mesh=dist_attr.process_mesh, + ctx=self._dist_context, + chunk_id=dist_attr.chunk_id, ) + if comm_stream is not None: + depend_op.dist_attr.execution_stream = comm_stream + depend_op.dist_attr.scheduling_priority = ( + self.comm_op_scheduling_priority + ) # hierarchical grad comm if self.enable_hierarchical_comm: From b2724a61b6b81040922a687a08a50db5235c8231 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Fri, 23 Feb 2024 20:33:09 +0800 Subject: [PATCH 11/18] Update code --- .../passes/auto_parallel_sharding.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 6101e744334c6..fce1d8eb1d629 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1087,6 +1087,18 @@ def op_depend_on_group(op, group): persistable=False, stop_gradient=True, ) + ref_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + group.vars[0] + ) + ) + set_var_dist_attr( + self._dist_context, + group.coalesce_var, + ref_dist_attr.dims_mapping, + ref_dist_attr.process_mesh, + ref_dist_attr.chunk_id, + ) coalesce_op_map[group.coalesce_op_idx] = group last_reduce_op_idx = group.reduce_op_indices.pop() modify_reduce_op_map[last_reduce_op_idx] = group @@ -1162,6 +1174,20 @@ def op_depend_on_group(op, group): OP_ROLE_KEY: OpRole.Backward, }, ) + + ref_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + group.coalesce_var + ) + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + coalesce_op, + ref_dist_attr.process_mesh, + ref_dist_attr.dims_mapping, + self._dist_context, + chunk_id=ref_dist_attr.chunk_id, + ) + depend_op = insert_dependencies_for_vars( block, idx, From c759901f518aaf94765c230b693b24077fa878b1 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 27 Feb 2024 18:36:06 +0800 Subject: [PATCH 12/18] Updatte code --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index fce1d8eb1d629..692ba05eb043b 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1097,7 +1097,7 @@ def op_depend_on_group(op, group): group.coalesce_var, 
ref_dist_attr.dims_mapping, ref_dist_attr.process_mesh, - ref_dist_attr.chunk_id, + chunk_id=ref_dist_attr.chunk_id, ) coalesce_op_map[group.coalesce_op_idx] = group last_reduce_op_idx = group.reduce_op_indices.pop() From 557770cfc79487f9b06dbbd40d7d41d39cf7c4e4 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 28 Feb 2024 20:59:27 +0800 Subject: [PATCH 13/18] Fix CI errors --- .../auto_parallel_data_parallel_optimization.py | 14 ++++++++++---- test/auto_parallel/test_dist_embedding.py | 5 ++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 9b26a0980e55f..97962f248cf89 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -440,7 +440,12 @@ def op_depend_on_group(op, group): def _update_program(self, grad_groups): block = default_main_program().global_block() - remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute'] + remove_op_types = [ + 'scale', + 'c_allreduce_avg', + 'c_allreduce_sum', + 'c_wait_compute', + ] for i, group in enumerate(grad_groups[::-1]): # skip unfused big tensor @@ -492,9 +497,10 @@ def _update_program(self, grad_groups): ) allreduce_op = block.ops[group.allreduce_op_idx] - assert ( - allreduce_op.type == 'c_allreduce_sum' - ), f"should found c_allreduce_sum op but found {str(allreduce_op)}" + assert allreduce_op.type in [ + 'c_allreduce_avg', + 'c_allreduce_sum', + ], f"should found c_allreduce_avg or c_allreduce_sum op but found {str(allreduce_op)}" allreduce_op_dist_attr = ( self.dist_context.get_op_dist_attr_for_program(allreduce_op) ) diff --git a/test/auto_parallel/test_dist_embedding.py b/test/auto_parallel/test_dist_embedding.py index f8dbd0fc9494d..b60b902f5516b 100644 --- a/test/auto_parallel/test_dist_embedding.py +++ b/test/auto_parallel/test_dist_embedding.py @@ -88,9 +88,8 @@ def test_lookup_table_v1_mp_dp(self): 'fill_constant', 'reduce_mean_grad', 'c_embedding_grad', - 'c_allreduce_sum', - 'scale', - ] + 'c_allreduce_avg', + ], f"Unexpexted op types: {op_types}" if __name__ == "__main__": From ca09321900a281ac7714febe129a650ab7cb1c25 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Thu, 29 Feb 2024 10:57:17 +0800 Subject: [PATCH 14/18] Register reduce_avg to pir --- .../framework/new_executor/pir_interpreter.cc | 4 ++++ .../pir/dialect/op_generator/ops_api_gen.py | 4 ++++ paddle/fluid/pir/dialect/operator/ir/ops.yaml | 20 +++++++++++++++++++ .../fluid/pir/dialect/operator/utils/utils.cc | 4 ++++ 4 files changed, 32 insertions(+) diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index fcb190a799922..6c5f4ea89940e 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -439,10 +439,12 @@ void PirInterpreter::UpdateNcclOpNum() { static std::set nccl_op_set = { "pd_op.c_softmax_with_cross_entropy", "pd_op.c_allgather", + "pd_op.c_allreduce_avg", "pd_op.c_allreduce_max", "pd_op.c_allreduce_min", "pd_op.c_allreduce_sum", "pd_op.c_allreduce_prod", + "pd_op.c_reduce_avg", "pd_op.c_reduce_max", "pd_op.c_reduce_min", "pd_op.c_reduce_prod", @@ -509,10 +511,12 @@ void PirInterpreter::UpdateNcclOpNum() { "pd_op.reduce_grad", "pd_op.c_softmax_with_cross_entropy_", "pd_op.c_allgather_", + 
"pd_op.c_allreduce_avg_", "pd_op.c_allreduce_max_", "pd_op.c_allreduce_min_", "pd_op.c_allreduce_sum_", "pd_op.c_allreduce_prod_", + "pd_op.c_reduce_avg_", "pd_op.c_reduce_max_", "pd_op.c_reduce_min_", "pd_op.c_reduce_prod_", diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 0212d41523444..d5c5c4f4102be 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -112,6 +112,8 @@ 'add_n_', 'add_n_with_kernel', 'c_allgather', + 'c_allreduce_avg', + 'c_allreduce_avg_', 'c_allreduce_max', 'c_allreduce_min', 'c_allreduce_min_', @@ -151,6 +153,8 @@ 'soft_relu', 'uniform_random_batch_size_like', 'match_matrix_tensor', + 'c_reduce_avg', + 'c_reduce_avg_', 'c_reduce_min', 'c_reduce_min_', 'push_sparse_v2', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 4ff53993ac4bc..3becb18fc581e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -156,6 +156,16 @@ kernel : func : c_allgather +- op : c_allreduce_avg + args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) + output : Tensor(out) + infer_meta : + func : AllReduceInferMeta + param : [x] + kernel : + func : c_allreduce_avg + inplace : (x -> out) + - op : c_allreduce_max args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) output : Tensor(out) @@ -236,6 +246,16 @@ func : c_identity inplace : (x -> out) +- op : c_reduce_avg + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_avg + inplace : (x -> out) + - op : c_reduce_min args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) output : Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 89b57608c1aec..3c0a3a1b3c875 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -52,6 +52,8 @@ const std::unordered_set LegacyOpList = { CAllreduceProd_Op::name(), CAllreduceSumOp::name(), CAllreduceSum_Op::name(), + CAllreduceAvgOp::name(), + CAllreduceAvg_Op::name(), CReduceSumOp::name(), CReduceSum_Op::name(), CAllreduceMax_Op::name(), @@ -80,6 +82,8 @@ const std::unordered_set LegacyOpList = { paddle::onednn::dialect::QuantizeOp::name(), paddle::onednn::dialect::RequantizeOp::name(), #endif + CReduceAvgOp::name(), + CReduceAvg_Op::name(), CReduceMinOp::name(), PushSparseV2Op::name()}; From edee0b9ad471b10a1c639068e7c56d8fbeda2a21 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Thu, 29 Feb 2024 19:58:16 +0800 Subject: [PATCH 15/18] Add op compat yaml --- paddle/phi/api/yaml/op_compat.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index cd296f7c302b9..be47d0ab12962 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3437,6 +3437,12 @@ outputs : out: Out +- op: c_allreduce_avg + inputs : + x : X + outputs : + out: Out + - op: c_allreduce_max inputs : x : X @@ -3473,6 +3479,12 @@ outputs : out: Out +- op: c_reduce_avg + inputs : + x : X + outputs : + out: Out + - op: c_reduce_min inputs : x : X From 3ac66e01733d74f1d1d88eba72191feacdd31fbc Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Sat, 
2 Mar 2024 17:59:36 +0800 Subject: [PATCH 16/18] Add gradient_scale_using_allreduce_avg args --- .../distributed/auto_parallel/constants.py | 1 + .../auto_parallel/static/dist_context.py | 15 ++++++++ .../auto_parallel/static/engine.py | 5 +++ .../auto_parallel/static/operators/common.py | 7 +++- test/auto_parallel/sharding_pass_unittest.py | 35 +++++++++++++++++-- test/auto_parallel/test_dist_embedding.py | 3 +- 6 files changed, 61 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index bcc64a50ae218..2fad0a278aeff 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -42,6 +42,7 @@ def set_field_default_config(category, field, default_value): BASE = "base" set_field_default_config(BASE, "auto_mode", "semi") set_field_default_config(BASE, "gradient_scale", True) +set_field_default_config(BASE, "gradient_scale_using_allreduce_avg", False) set_field_default_config(BASE, "use_cache", True) set_field_default_config(BASE, "return_numpy", True) set_field_default_config(BASE, "all_ranks", False) diff --git a/python/paddle/distributed/auto_parallel/static/dist_context.py b/python/paddle/distributed/auto_parallel/static/dist_context.py index eefc0d332957f..12d88ba779d3f 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_context.py +++ b/python/paddle/distributed/auto_parallel/static/dist_context.py @@ -127,6 +127,9 @@ def __init__( # flag whether scale gradient with dp size self._gradient_scale = True + # whether use allreduce_avg to scale gradient, i.e., allreduce_sum + scale -> allreduce_avg + self._gradient_scale_using_allreduce_avg = False + # A flag indicates whether the used parallelism is data parallel self._data_parallel = False @@ -220,6 +223,18 @@ def gradient_scale(self): def gradient_scale(self, gs): self._gradient_scale = gs + @property + def gradient_scale_using_allreduce_avg(self): + return self._gradient_scale_using_allreduce_avg + + @gradient_scale_using_allreduce_avg.setter + def gradient_scale_using_allreduce_avg( + self, gradient_scale_using_allreduce_avg + ): + self._gradient_scale_using_allreduce_avg = ( + gradient_scale_using_allreduce_avg + ) + @property def data_parallel(self): return self._data_parallel diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 99daa366b32e5..4fef710e8fd73 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -774,6 +774,11 @@ def _build(self, mode): self._json_config, ) self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale + self._dist_contexts[ + mode + ].gradient_scale_using_allreduce_avg = ( + self._strategy.gradient_scale_using_allreduce_avg + ) self._fwd_main_progs[mode] = serial_main_prog.clone() def _optimization_tuning(self, mode, dataset, batch_size): diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 50247ba92812d..17dae832485f7 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -505,9 +505,14 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names): allreduce_type = "c_allreduce_sum" need_scale = dist_ctx.gradient_scale + 
scale_using_allreduce_avg = dist_ctx.gradient_scale_using_allreduce_avg # With nccl_version > 2.10.00, we can use c_allreduce_avg to replace c_allreduce_sum and eliminate the scale op. - if need_scale and int(paddle.version.nccl()) > 21000: + if ( + need_scale + and scale_using_allreduce_avg + and int(paddle.version.nccl()) > 21000 + ): allreduce_type = "c_allreduce_avg" need_scale = False diff --git a/test/auto_parallel/sharding_pass_unittest.py b/test/auto_parallel/sharding_pass_unittest.py index 82d17e821b7db..14ee0a22e4c0f 100644 --- a/test/auto_parallel/sharding_pass_unittest.py +++ b/test/auto_parallel/sharding_pass_unittest.py @@ -24,9 +24,10 @@ paddle.enable_static() -def apply_pass(use_sharding=False, stage=None): +def apply_pass(use_sharding=False, stage=None, use_allreduce_avg=False): strategy = auto.Strategy() strategy.auto_mode = "semi" + strategy.gradient_scale_using_allreduce_avg = use_allreduce_avg # strategy.reinit = True if use_sharding: sharding = strategy.sharding @@ -67,10 +68,12 @@ def init(self, engine): np.random.seed(2022) random.seed(2022) - def get_engine(self, use_sharding=False, stage=None): + def get_engine( + self, use_sharding=False, stage=None, use_allreduce_avg=False + ): reset_prog() - strategy = apply_pass(use_sharding, stage) + strategy = apply_pass(use_sharding, stage, use_allreduce_avg) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) # NOTE: setting opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) will cause precision problem opt = paddle.optimizer.AdamW(learning_rate=0.00001) @@ -150,6 +153,32 @@ def test_sharding_pass(self): sharding3_losses = np.array(history.history["loss"]) self.check_results(dp_losses, sharding3_losses) + # dp2 training using allreduce avg + dp_engine_using_allreduce_avg = self.get_engine(use_allreduce_avg=True) + dp_engine_using_allreduce_avg.prepare( + inputs_spec=input_spec, labels_spec=label_spec, mode='train' + ) + dp_engine_using_allreduce_avg.save( + "./dp_engine_using_allreduce_avg", training=True + ) + history = dp_engine_using_allreduce_avg.fit( + self.dataset, 3, batch_size=self.batch_size + ) + dp_losses_using_allreduce_avg = np.array(history.history["loss"]) + + # sharding2 stage2 training using allreduce avg + sharding2_engine_using_allreduce_avg = self.get_engine(True, 2) + sharding2_engine_using_allreduce_avg.load( + "./dp_engine_using_allreduce_avg" + ) + history = sharding2_engine_using_allreduce_avg.fit( + self.dataset, 3, batch_size=self.batch_size + ) + sharding2_losses_using_allreduce_avg = np.array(history.history["loss"]) + self.check_results( + dp_losses_using_allreduce_avg, sharding2_losses_using_allreduce_avg + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/test_dist_embedding.py b/test/auto_parallel/test_dist_embedding.py index b60b902f5516b..7304b06aeb274 100644 --- a/test/auto_parallel/test_dist_embedding.py +++ b/test/auto_parallel/test_dist_embedding.py @@ -88,7 +88,8 @@ def test_lookup_table_v1_mp_dp(self): 'fill_constant', 'reduce_mean_grad', 'c_embedding_grad', - 'c_allreduce_avg', + 'c_allreduce_sum', + 'scale', ], f"Unexpexted op types: {op_types}" From 38ddbb0c83db9f0f5896b0390b475c7e2e711de1 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Sun, 3 Mar 2024 17:33:20 +0800 Subject: [PATCH 17/18] Fix CI errors --- test/auto_parallel/sharding_pass_unittest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auto_parallel/sharding_pass_unittest.py b/test/auto_parallel/sharding_pass_unittest.py index 
14ee0a22e4c0f..762fb6e239582 100644 --- a/test/auto_parallel/sharding_pass_unittest.py +++ b/test/auto_parallel/sharding_pass_unittest.py @@ -167,7 +167,7 @@ def test_sharding_pass(self): dp_losses_using_allreduce_avg = np.array(history.history["loss"]) # sharding2 stage2 training using allreduce avg - sharding2_engine_using_allreduce_avg = self.get_engine(True, 2) + sharding2_engine_using_allreduce_avg = self.get_engine(True, 2, True) sharding2_engine_using_allreduce_avg.load( "./dp_engine_using_allreduce_avg" ) From 1d3918bca8e352971bc9b61005b828f21a353e48 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Sun, 3 Mar 2024 18:37:25 +0800 Subject: [PATCH 18/18] Add NOTE --- .../distributed/passes/auto_parallel_sharding.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 692ba05eb043b..d00736629d73e 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1320,6 +1320,17 @@ def _overlap_grad_comm( ) idx += 1 + # NOTE(Ruibiao): Why add dependecy here? + # It is hack to delay GC for coalesce_var, which significantly reduce memory usage. + # With the pattern of reduce_sum + scale, the coalesce_var is used by the reduce_sum + # op on the comm-stream, and then released by the scale op on the comp-stream. Since + # the generated and released op are both in comp-stream, the allocation of the + # coalesce_var can be fast-GC and reused by subsequent comp-op. However in reduce_avg + # parrent, the coalesce_var is released on the reduce_avg op in comm-stream, + # triggering a cross-stream GC. In such case, an event is recorded on the underlying + # allocation, and the memory is unable to reused by other comp-ops, resulting in an + # increase in memory usage. For more details, see the code of StreamSafeCUDAAllocator. + # This issue should be fixed using CUDAMallocAsyncAllocator in the future. if ( op.type == "c_reduce_avg" and not grad_group.is_in_local_shard