[Dygraph] Remove unrequired UT cases of DP in eager mode #41413

Merged
5 commits merged on Apr 6, 2022
40 changes: 36 additions & 4 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -20,6 +20,7 @@
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, build_groups
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from collections import OrderedDict
from .log_util import logger

@@ -58,6 +59,30 @@ def _apply_collective_grads(parameters, comm_group):
    _split_tensors(coalesced_grads_and_vars)


def _apply_collective_grads_eager(parameters, comm_group):
    grad_var_set = set()
    grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert not g_var.is_sparse(
            ), "Now, it doesn't support sparse parameters"
            grad_vars.append(g_var)
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    coalesced_grads_and_vars = build_groups(grad_vars, 128 * 1024 * 1024)

    div_factor = 1.0 / comm_group.nranks
    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # need to div nranks
        coalesced_grad.scale_(div_factor)
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    _split_tensors(coalesced_grads_and_vars)


def _broadcast_data_help(data, shape, dtype, hcg):
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
@@ -115,10 +140,17 @@ def broadcast_dp_parameters(model, hcg):


def fused_allreduce_gradients(parameter_list, hcg):
    data_parallel_group = None if hcg is None else hcg.get_data_parallel_group()
    logger.debug("dp start fuse allreduce gradients")
    with framework.no_grad():
        _apply_collective_grads(parameter_list, data_parallel_group)
    if _in_legacy_dygraph():
        data_parallel_group = None if hcg is None else hcg.get_data_parallel_group(
        )
        logger.debug("dp start fuse allreduce gradients")
        with framework.no_grad():
            _apply_collective_grads(parameter_list, data_parallel_group)
    elif in_dygraph_mode():
        assert hcg is None, "It's not support to use hcg in EagerDygraph now."
        data_parallel_group = paddle.distributed.collective._get_default_group()
        with framework.no_grad():
            _apply_collective_grads_eager(parameter_list, data_parallel_group)


def sharding_reduce_gradients(parameter_list, hcg):
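Note (editor): the sketch below is not part of the PR. It illustrates how the new eager-mode path of fused_allreduce_gradients could be driven from a hand-written data-parallel training step, assuming two or more GPUs, a launch via `python -m paddle.distributed.launch`, and the Paddle revision of this PR, where the eager branch requires hcg to be None and falls back to the default collective group.

# Hypothetical usage sketch (not from this PR); run it with
# `python -m paddle.distributed.launch` on 2+ GPUs.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

dist.init_parallel_env()  # sets up the default data-parallel process group

model = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

x = paddle.randn([4, 10])
loss = model(x).mean()
loss.backward()

# With hcg=None the eager branch picks the default collective group,
# scales every coalesced gradient by 1/nranks and all-reduces it in place.
fused_allreduce_gradients(list(model.parameters()), None)

opt.step()
opt.clear_grad()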
34 changes: 23 additions & 11 deletions python/paddle/fluid/dygraph/parallel.py
@@ -22,6 +22,7 @@
from contextlib import contextmanager

import paddle
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
@@ -307,17 +308,28 @@ def _reshape_inplace(x, shape):

@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
        grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_grad},
            outputs={'Out': origin_grad_vars},
            attrs={'sections': grad_var_len,
                   'axis': 0})
        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
            _reshape_inplace(x=g_var, shape=g_shape)
            assert g_var.shape == g_shape
    if _in_legacy_dygraph():
        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
            framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced_grad},
                outputs={'Out': origin_grad_vars},
                attrs={'sections': grad_var_len,
                       'axis': 0})
            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
                _reshape_inplace(x=g_var, shape=g_shape)
                assert g_var.shape == g_shape
    elif in_dygraph_mode():
        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
            attrs = ()
            attrs += ('sections', grad_var_len)
            attrs += ('axis', 0)
            _C_ops.split(coalesced_grad, origin_grad_vars, *attrs)
            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
                g_var.reshape_(shape=g_shape)
                assert g_var.shape == g_shape


def scale_loss(loss):
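Note (editor): the following is a minimal, self-contained sketch of the coalesce-then-split round trip that build_groups and _split_tensors implement, written with public paddle ops instead of the internal helpers. The variable names (grads, flat, pieces) are invented here and nothing below is taken from the PR.

# Illustrative sketch only, not part of this PR.
import numpy as np
import paddle

grads = [paddle.randn([2, 3]), paddle.randn([4]), paddle.randn([1, 5])]
shapes = [g.shape for g in grads]

# Coalesce: flatten each gradient and concatenate into one buffer, so a single
# scale_ / all_reduce can touch all of them at once.
flat = paddle.concat([paddle.reshape(g, [-1]) for g in grads], axis=0)

# ... the collective communication would operate on `flat` here ...

# Split back: sections are the element counts, then restore the original shapes,
# mirroring the `sections` / `axis=0` attributes and the reshape_ in the diff.
sections = [int(np.prod(s)) for s in shapes]
pieces = paddle.split(flat, num_or_sections=sections, axis=0)
restored = [paddle.reshape(p, s) for p, s in zip(pieces, shapes)]

for r, s in zip(restored, shapes):
    assert r.shape == s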
@@ -21,7 +21,8 @@
import numpy as np
import paddle.distributed as dist
from paddle.fluid.dygraph.nn import Linear
from paddle.autograd import PyLayer
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

batch = 5
@@ -43,6 +44,20 @@ def backward(ctx, dy):
        return grad


class cus_tanh_eager(EagerPyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad


class SimpleNet(paddle.nn.Layer):
    def __init__(self, train_id, model_id):
        super(SimpleNet, self).__init__()
@@ -55,7 +70,10 @@ def __init__(self, train_id, model_id):

    def forward(self, inputs):
        if self.model_id == 0:
            inputs = cus_tanh.apply(inputs)
            if in_dygraph_mode():
                inputs = cus_tanh_eager.apply(inputs)
            elif _in_legacy_dygraph():
                inputs = cus_tanh.apply(inputs)
        else:
            inputs = self.tanh(inputs)
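Note (editor): the test above registers an EagerPyLayer twin of the custom tanh. The single-process sketch below shows the same custom-gradient pattern through paddle.autograd.PyLayer (the stable name; EagerPyLayer is its eager-mode counterpart in this Paddle revision) and checks the hand-written backward against the analytic derivative. It is illustrative only and not part of the PR.

# Illustrative sketch, not part of the PR. Assumes Paddle >= 2.2, where
# Tensor.grad returns a Tensor.
import numpy as np
import paddle
from paddle.autograd import PyLayer


class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        return dy * (1 - paddle.square(y))  # d tanh(x)/dx = 1 - tanh(x)^2


x_np = np.random.rand(3, 4).astype("float32")
x = paddle.to_tensor(x_np, stop_gradient=False)
CusTanh.apply(x).sum().backward()

# The hand-written backward should match the analytic gradient of sum(tanh(x)).
assert np.allclose(x.grad.numpy(), 1.0 - np.tanh(x_np) ** 2, atol=1e-5)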
@@ -23,6 +23,7 @@
import subprocess

from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
from paddle.fluid.framework import _test_eager_guard


def get_cluster_from_args(selected_gpus):
@@ -205,6 +206,8 @@ def test_multiple_gpus_dynamic(self):

class TestDataParallelWithPyLayer(TestMultipleGpus):
    def test_parallel_dygraph_dataparallel_with_pylayer(self):
        with _test_eager_guard():
            self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')
        self.run_mnist_2gpu('parallel_dygraph_dataparallel_with_pylayer.py')


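Note (editor): the change above runs the 2-GPU PyLayer test once inside _test_eager_guard() and once outside, so both eager and legacy dygraph are covered. Below is a hedged, single-process sketch of that dual-mode pattern, assuming the _test_eager_guard and in_dygraph_mode helpers of this Paddle revision; the test class and method names are made up and not part of the PR.

# Illustrative sketch of the dual-mode test pattern, not part of the PR.
import unittest

import paddle
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode


class TestRunsInBothModes(unittest.TestCase):
    def _body(self):
        # Any dygraph workload; here a tiny forward/backward pass.
        x = paddle.randn([2, 3])
        x.stop_gradient = False
        paddle.tanh(x).sum().backward()
        self.assertIsNotNone(x.grad)

    def test_body_in_both_modes(self):
        with _test_eager_guard():  # eager (final-state) dygraph
            self.assertTrue(in_dygraph_mode())
            self._body()
        self._body()  # again in whichever mode is the default outside the guard


if __name__ == "__main__":
    unittest.main()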
@@ -55,35 +55,5 @@ def test_sparse_embedding_fp64(self):
            log_name=flag_name)


class TestParallelDygraphSparseEmdeddingEager_GLOO(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        self.check_with_place(
            "parallel_dygraph_sparse_embedding.py",
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)


class TestParallelDygraphSparseEmdeddingEagerFP64_GLOO(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding_fp64(self):
        self.check_with_place(
            "parallel_dygraph_sparse_embedding_fp64.py",
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)


if __name__ == "__main__":
    unittest.main()
@@ -40,20 +40,5 @@ def test_sparse_embedding(self):
            log_name=flag_name)


class TestParallelDygraphSparseEmdeddingOverHeightEager_GLOO(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_sparse_embedding(self):
        self.check_with_place(
            "parallel_dygraph_sparse_embedding_over_height.py",
            delta=1e-7,
            check_error_log=True,
            log_name=flag_name)


if __name__ == "__main__":
    unittest.main()
@@ -57,20 +57,5 @@ def test_transformer(self):
            log_name=flag_name)


class TestParallelDygraphTransformerEager_GLOO(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._eager_mode = True
        self._gloo_mode = True
        self._dygraph = True

    def test_transformer(self):
        self.check_with_place(
            "parallel_dygraph_transformer.py",
            delta=1e-5,
            check_error_log=True,
            log_name=flag_name)


if __name__ == "__main__":
    unittest.main()