Using allreduce_avg to eliminate scale in auto parallel DP #61622

Merged
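In data-parallel training, each rank's gradients are usually synchronized with a c_allreduce_sum followed by a separate scale op that divides by the DP degree. This PR adds c_allreduce_avg / c_reduce_avg ops (mapped to NCCL's ncclAvg) so the sum and the scale collapse into a single collective. A minimal sketch of the equivalence, simulating the ranks with NumPy (the helper names are illustrative, not Paddle APIs):

```python
import numpy as np

def allreduce_sum(grads_per_rank):
    # every rank receives the elementwise sum of all ranks' gradients
    total = np.sum(grads_per_rank, axis=0)
    return [total.copy() for _ in grads_per_rank]

def allreduce_avg(grads_per_rank):
    # ncclAvg-style collective: sum and divide by the rank count in one step
    mean = np.mean(grads_per_rank, axis=0)
    return [mean.copy() for _ in grads_per_rank]

dp_degree = 4
grads = [np.random.rand(3) for _ in range(dp_degree)]

# old pattern: allreduce_sum + explicit scale op
scaled = [g / dp_degree for g in allreduce_sum(grads)]

# new pattern: a single allreduce_avg, no trailing scale op
averaged = allreduce_avg(grads)

assert all(np.allclose(a, b) for a, b in zip(scaled, averaged))
```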
4 changes: 4 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -439,10 +439,12 @@ void PirInterpreter::UpdateNcclOpNum() {
static std::set<std::string> nccl_op_set = {
"pd_op.c_softmax_with_cross_entropy",
"pd_op.c_allgather",
"pd_op.c_allreduce_avg",
"pd_op.c_allreduce_max",
"pd_op.c_allreduce_min",
"pd_op.c_allreduce_sum",
"pd_op.c_allreduce_prod",
"pd_op.c_reduce_avg",
"pd_op.c_reduce_max",
"pd_op.c_reduce_min",
"pd_op.c_reduce_prod",
@@ -509,10 +511,12 @@ void PirInterpreter::UpdateNcclOpNum() {
"pd_op.reduce_grad",
"pd_op.c_softmax_with_cross_entropy_",
"pd_op.c_allgather_",
"pd_op.c_allreduce_avg_",
"pd_op.c_allreduce_max_",
"pd_op.c_allreduce_min_",
"pd_op.c_allreduce_sum_",
"pd_op.c_allreduce_prod_",
"pd_op.c_reduce_avg_",
"pd_op.c_reduce_max_",
"pd_op.c_reduce_min_",
"pd_op.c_reduce_prod_",
45 changes: 45 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_avg_op.cc
@@ -0,0 +1,45 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle

namespace paddle {
namespace operators {

class CAllReduceAvgOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Avg"; }
};

DECLARE_INPLACE_OP_INFERER(AllreduceAvgInplaceInferer, {"X", "Out"});

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_avg,
ops::CAllReduceOp,
ops::CAllReduceAvgOpMaker,
ops::AllreduceAvgInplaceInferer)
35 changes: 35 additions & 0 deletions paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc
@@ -0,0 +1,35 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace operators {
DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceAvg, kRedAvg)
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

PD_REGISTER_STRUCT_KERNEL(c_allreduce_avg,
GPU,
ALL_LAYOUT,
ops::CAllReduceAvgCUDAKernel,
float,
double,
int,
int64_t,
plat::float16,
plat::bfloat16) {}
8 changes: 7 additions & 1 deletion paddle/fluid/operators/collective/c_allreduce_op.h
@@ -48,7 +48,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle {
namespace operators {

enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg };

class CAllReduceOp : public framework::OperatorWithKernel {
public:
@@ -413,6 +413,12 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
nccl_red_type = ncclProd;
break;

#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
case kRedAvg:
nccl_red_type = ncclAvg;
break;
#endif

default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid reduce type: %d", red_type));
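The kRedAvg branch maps to ncclAvg, which NCCL only provides from 2.10 onward (hence the NCCL_VERSION_CODE >= 21000 and CUDA_VERSION >= 11000 guard); on older stacks the case is compiled out and the op falls through to the invalid-reduce-type error. A hedged sketch of how a program builder could fall back to the classic sum + scale pair when the fused collective is unavailable (the version tuple and helper are illustrative, not part of this PR):

```python
def grad_sync_ops(nccl_version, dp_degree):
    """Pick the op sequence used to average DP gradients.

    nccl_version is a (major, minor) tuple; ncclAvg needs NCCL >= 2.10.
    """
    if nccl_version >= (2, 10):
        # fused path: one collective both sums and divides by dp_degree
        return ["c_allreduce_avg"]
    # fallback path: sum, then an explicit scale by 1 / dp_degree
    return ["c_allreduce_sum", f"scale(1/{dp_degree})"]

print(grad_sync_ops((2, 12), dp_degree=8))  # ['c_allreduce_avg']
print(grad_sync_ops((2, 8), dp_degree=8))   # ['c_allreduce_sum', 'scale(1/8)']
```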
44 changes: 44 additions & 0 deletions paddle/fluid/operators/collective/c_reduce_avg_op.cc
@@ -0,0 +1,44 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace framework {
class OpDesc;
template <typename T>
class EmptyGradOpMaker;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle

namespace paddle {
namespace operators {

class CReduceAvgOpMaker : public CReduceOpMaker {
protected:
std::string GetName() const override { return "Avg"; }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_reduce_avg,
ops::CReduceOp,
ops::CReduceAvgOpMaker);
35 changes: 35 additions & 0 deletions paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc
@@ -0,0 +1,35 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace operators {
DEFINE_C_REDUCE_CUDA_KERNEL(CReduceAvg, kRedAvg);
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

PD_REGISTER_STRUCT_KERNEL(c_reduce_avg,
GPU,
ALL_LAYOUT,
ops::CReduceAvgCUDAKernel,
float,
double,
int,
int64_t,
plat::float16,
plat::bfloat16) {}
8 changes: 7 additions & 1 deletion paddle/fluid/operators/collective/c_reduce_op.h
@@ -50,7 +50,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle {
namespace operators {

enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg };

class CReduceOp : public framework::OperatorWithKernel {
public:
@@ -304,6 +304,12 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
nccl_red_type = ncclProd;
break;

#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
case kRedAvg:
nccl_red_type = ncclAvg;
break;
#endif

default:
PADDLE_ENFORCE_EQ(true,
false,
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -119,6 +119,8 @@
NO_NEED_GEN_STATIC_ONLY_APIS = [
'add_n_',
'c_allgather',
'c_allreduce_avg',
'c_allreduce_avg_',
'c_allreduce_max',
'c_allreduce_min',
'c_allreduce_min_',
@@ -158,6 +160,8 @@
'soft_relu',
'uniform_random_batch_size_like',
'match_matrix_tensor',
'c_reduce_avg',
'c_reduce_avg_',
'c_reduce_max',
'c_reduce_max_',
'c_reduce_min',
20 changes: 20 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -138,6 +138,16 @@
kernel :
func : c_allgather

- op : c_allreduce_avg
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
infer_meta :
func : AllReduceInferMeta
param : [x]
kernel :
func : c_allreduce_avg
inplace : (x -> out)

- op : c_allreduce_max
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
@@ -218,6 +228,16 @@
func : c_identity
inplace : (x -> out)

- op : c_reduce_avg
args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
output : Tensor(out)
infer_meta :
func : DistReduceInferMeta
param : [x]
kernel :
func : c_reduce_avg
inplace : (x -> out)

- op : c_reduce_max
args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
output : Tensor(out)
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -50,6 +50,8 @@ const std::unordered_set<std::string> LegacyOpList = {
CAllreduceProd_Op::name(),
CAllreduceSumOp::name(),
CAllreduceSum_Op::name(),
CAllreduceAvgOp::name(),
CAllreduceAvg_Op::name(),
CReduceSumOp::name(),
CReduceSum_Op::name(),
CAllreduceMax_Op::name(),
@@ -86,6 +88,8 @@
paddle::onednn::dialect::MultiGruOp::name(),
paddle::onednn::dialect::FusionLstmOp::name(),
#endif
CReduceAvgOp::name(),
CReduceAvg_Op::name(),
CReduceMaxOp::name(),
CReduceMinOp::name(),
CReduceProdOp::name(),
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3513,6 +3513,12 @@
outputs :
out: Out

- op: c_allreduce_avg
inputs :
x : X
outputs :
out: Out

- op: c_allreduce_max
inputs :
x : X
@@ -3549,6 +3555,12 @@
outputs :
out: Out

- op: c_reduce_avg
inputs :
x : X
outputs :
out: Out

- op: c_reduce_max
inputs :
x : X
2 changes: 2 additions & 0 deletions python/env_dict.py.in
@@ -1,9 +1,11 @@
env_dict={
'NCCL_VERSION':'@NCCL_VERSION@',
'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@',
'PADDLE_VERSION':'@PADDLE_VERSION@',
'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@',
'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@',
'WITH_GPU':'@WITH_GPU@',
'WITH_NCCL':'@WITH_NCCL@',
'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@',
'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@',
'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@',
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -42,6 +42,7 @@ def set_field_default_config(category, field, default_value):
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "gradient_scale_using_allreduce_avg", False)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
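The new gradient_scale_using_allreduce_avg field defaults to False, so existing configs keep the allreduce_sum + scale behavior. A hedged sketch of switching it on through the auto-parallel strategy (the exact import path and Engine wiring are assumptions; the diff only shows the flag being read from the strategy in engine.py below):

```python
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.gradient_scale = True
# new flag from this PR: emit c_allreduce_avg instead of
# c_allreduce_sum followed by a scale op when averaging DP gradients
strategy.gradient_scale_using_allreduce_avg = True

# engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)
# engine.fit(train_dataset, epochs=1, batch_size=8)
```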
15 changes: 15 additions & 0 deletions python/paddle/distributed/auto_parallel/static/dist_context.py
@@ -127,6 +127,9 @@ def __init__(
# flag whether scale gradient with dp size
self._gradient_scale = True

# whether to use allreduce_avg to scale gradients, i.e., allreduce_sum + scale -> allreduce_avg
self._gradient_scale_using_allreduce_avg = False

# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False

@@ -220,6 +223,18 @@ def gradient_scale(self):
def gradient_scale(self, gs):
self._gradient_scale = gs

@property
def gradient_scale_using_allreduce_avg(self):
return self._gradient_scale_using_allreduce_avg

@gradient_scale_using_allreduce_avg.setter
def gradient_scale_using_allreduce_avg(
self, gradient_scale_using_allreduce_avg
):
self._gradient_scale_using_allreduce_avg = (
gradient_scale_using_allreduce_avg
)

@property
def data_parallel(self):
return self._data_parallel
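Once the flag reaches the dist context, the data-parallel gradient-sync logic can choose the fused collective instead of appending a scale op after the sum. A simplified, hypothetical sketch of that decision; the real pass in Paddle is more involved, and only the op names below come from this diff:

```python
def build_grad_sync_ops(grad_name, dp_group, dist_context):
    """Return (op_type, attrs) pairs appended after a DP gradient is produced."""
    ring_id = dp_group.id
    if dist_context.gradient_scale_using_allreduce_avg:
        # fused path: one collective sums and divides by the DP world size
        return [("c_allreduce_avg", {"X": grad_name, "ring_id": ring_id})]
    ops = [("c_allreduce_sum", {"X": grad_name, "ring_id": ring_id})]
    if dist_context.gradient_scale:
        # legacy path: explicit scale by 1 / dp_degree after the sum
        ops.append(("scale", {"X": grad_name, "scale": 1.0 / dp_group.nranks}))
    return ops
```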
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -130,6 +130,8 @@ def __str__(self):
f", process_mesh ({annotated_str}): {self.dist_attr.process_mesh}"
)

str += f" , execution_stream: {self.dist_attr.execution_stream}"

for arg_name in self.serial_op.desc.input_arg_names():
try:
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -779,6 +779,11 @@ def _build(self, mode):
self._json_config,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._dist_contexts[
mode
].gradient_scale_using_allreduce_avg = (
self._strategy.gradient_scale_using_allreduce_avg
)
self._fwd_main_progs[mode] = serial_main_prog.clone()

def _optimization_tuning(self, mode, dataset, batch_size):