Move group and all reduce from collective to communication (#45848)
LiYuRio authored and HermitSun committed Oct 12, 2022
1 parent 9641b93 commit 868b1e1
Showing 11 changed files with 276 additions and 184 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.cc
@@ -293,6 +293,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    const AllreduceOptions& opts) {
  return AllReduce(inputs, outputs, opts, true);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    const AllreduceOptions& opts,
    bool sync_op) {
  auto tag = next_tag();
  std::shared_ptr<GlooTask> task;
  auto context = get_context();
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.h
@@ -120,6 +120,12 @@ class ProcessGroupGloo : public ProcessGroup {
      std::vector<phi::DenseTensor>& outputs,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& inputs,
      std::vector<phi::DenseTensor>& outputs,
      const AllreduceOptions& opts,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

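In both files above, ProcessGroupGloo::AllReduce gains a second overload that carries an explicit sync_op flag, and the original three-argument form now simply forwards to it with sync_op = true. The toy Python sketch below (not Paddle code; every name in it is invented) mirrors that forwarding pattern, assuming sync_op keeps the caller-visible "block before returning" semantics described in the Python docstring later in this diff:

# Toy illustration only: the old entry point forwards to a sync_op-aware overload.
class _ToyTask:
    def __init__(self):
        self._done = False

    def wait(self):
        # Stand-in for blocking until the collective finishes.
        self._done = True


class _ToyProcessGroup:
    def all_reduce(self, tensors, opts=None):
        # Old signature: keep the historical blocking behaviour.
        return self.all_reduce_with_sync_op(tensors, opts, sync_op=True)

    def all_reduce_with_sync_op(self, tensors, opts, sync_op):
        task = _ToyTask()  # stand-in for launching the collective
        if sync_op:
            task.wait()    # synchronous: block before returning
        return task        # asynchronous callers call wait() later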
173 changes: 11 additions & 162 deletions python/paddle/distributed/collective.py
@@ -52,54 +52,12 @@
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
from .communication.comm_utils import ReduceOp
from .communication.group import Group, _add_new_group
from .communication.all_reduce import all_reduce
from .communication.reduce import _get_reduce_op, ReduceOp

__all__ = []


class Group():
    """
    The abstract representation of group.
    """

    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
        self.rank = rank
        self.nranks = rank_num
        self.id = id
        self.ranks = ranks
        self.pg = pg
        self.name = name

    def is_member(self):
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        if self.is_member() and rank in self.ranks:
            return self.ranks.index(rank)
        else:
            return -1

    @property
    def process_group(self):
        return self.pg

    @property
    def world_size(self):
        return self.nranks if self.rank >= 0 else -1

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str


_global_env = None


@@ -147,9 +105,8 @@ def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(genv.rank,
                                            genv.world_size,
                                            ranks=list(range(genv.world_size)))
        _group_map[_global_env_gid] = Group(genv.rank, 0,
                                            list(range(genv.world_size)))
    return _group_map


@@ -197,19 +154,6 @@ def _new_ring_id():
    return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _get_reduce_op(reduce_op, func_name):
    if reduce_op == ReduceOp.SUM:
        return core.ReduceOp.SUM
    elif reduce_op == ReduceOp.MAX:
        return core.ReduceOp.MAX
    elif reduce_op == ReduceOp.MIN:
        return core.ReduceOp.MIN
    elif reduce_op == ReduceOp.PROD:
        return core.ReduceOp.PRODUCT
    else:
        raise ValueError("Unknown reduce_op type for {}.".format(func_name))


def get_group(id=0):
"""
@@ -451,10 +395,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
        else:
            rank = -1
            pg = None
        group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        #TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
@@ -476,13 +423,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, -1, ring_id, ranks)
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, group_size, ring_id, ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
@@ -748,104 +695,6 @@ def broadcast(tensor, src, group=None, sync_op=True):
                     })


def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    """
    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.
    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center
    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. Default value is True.
    Returns:
        None.
    Examples:
        .. code-block:: python
            # required: distributed
            import paddle
            import paddle.distributed as dist
            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        op_type = _get_reduce_op(op, "all_reduce")
        group = _get_default_group() if group is None else group
        task = group.process_group.allreduce(tensor, op_type)
        if sync_op:
            task.wait()
            return None
        else:
            return task

    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    if _non_static_mode():
        if op == ReduceOp.SUM:
            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MAX:
            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.MIN:
            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                  use_calc_stream, 'ring_id',
                                                  ring_id)
        elif op == ReduceOp.PROD:
            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                   use_calc_stream, 'ring_id',
                                                   ring_id)
        else:
            raise ValueError("Unknown parameter: {}.".format(op))

    check_variable_and_dtype(tensor, 'tensor', [
        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
        'bool'
    ], 'all_reduce')
    if op == ReduceOp.SUM:
        op_type = 'c_allreduce_sum'
    elif op == ReduceOp.MAX:
        op_type = 'c_allreduce_max'
    elif op == ReduceOp.MIN:
        op_type = 'c_allreduce_min'
    elif op == ReduceOp.PROD:
        op_type = 'c_allreduce_prod'
    if not isinstance(ring_id, int):
        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
    helper = LayerHelper(op_type, **locals())
    helper.append_op(type=op_type,
                     inputs={'X': [tensor]},
                     outputs={'Out': [tensor]},
                     attrs={
                         'ring_id': ring_id,
                         'use_calc_stream': use_calc_stream
                     })


def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
"""
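Across the collective.py hunks above, every Group(...) call site drops the explicit world-size argument: the old Group(rank, rank_num, id=0, ranks=[], pg=None, name=None) constructor gives way to a three-positional form such as Group(rank, gid, ranks, pg=pg, name=group_name). The relocated class itself lives in python/paddle/distributed/communication/group.py, which is not part of this excerpt; the sketch below is only a reconstruction inferred from those call sites, with the world size derived from ranks:

# Rough reconstruction inferred from the call sites in this diff; not the real class.
class Group:
    def __init__(self, rank, id, ranks, pg=None, name=None):
        self.rank = rank    # this process's rank inside the group, or -1 for non-members
        self.id = id        # group / ring id (the gid passed at the call sites)
        self.ranks = ranks  # global ranks that belong to the group
        self.pg = pg        # backing ProcessGroup, if one was created
        self.name = name

    @property
    def world_size(self):
        # No longer passed in explicitly; derived from the rank list.
        return len(self.ranks) if self.rank >= 0 else -1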
87 changes: 87 additions & 0 deletions python/paddle/distributed/communication/all_reduce.py
@@ -0,0 +1,87 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid.framework as framework
from paddle.distributed.communication import stream as stream
from paddle.distributed.communication.reduce import ReduceOp


def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    """
    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.
    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center
    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for global default group.
        sync_op (bool, optional): Whether this op is a sync op. Default value is True.
    Returns:
        Return a task object.
    Examples:
        .. code-block:: python
            # required: distributed
            import paddle
            import paddle.distributed as dist
            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    if not framework._in_legacy_dygraph():
        return stream.all_reduce(tensor,
                                 op=op,
                                 group=group,
                                 sync_op=sync_op,
                                 use_calc_stream=False)

    # code below will be removed after we remove the old dygraph
    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    if op == ReduceOp.SUM:
        return paddle._legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.MAX:
        return paddle._legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.MIN:
        return paddle._legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.PROD:
        return paddle._legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                      use_calc_stream,
                                                      'ring_id', ring_id)
    else:
        raise ValueError("Unknown parameter: {}.".format(op))
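For reference, a minimal usage sketch of the relocated API from the caller's side, assuming an eager-mode build launched on two GPUs (for example via paddle.distributed.launch); the synchronous call follows the docstring example above, and the asynchronous variant uses the sync_op flag together with the returned task:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])

# Synchronous (default): returns once every rank holds the reduced result.
dist.all_reduce(data, op=dist.ReduceOp.SUM)

# Asynchronous: get a task back and wait on it before reading the tensor.
task = dist.all_reduce(data, sync_op=False)
task.wait()
print(data)  # reduced twice across 2 ranks by the two calls above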
(The remaining 7 changed files are not shown here.)
