Skip to content

Commit

Permalink
Fix pipeline in new dygraph (#41937) (#42053)
Browse files Browse the repository at this point in the history
* fix utest

* fix time
  • Loading branch information
ForFishes authored Apr 21, 2022
1 parent 50fd245 commit 7eae657
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# The file has been adapted from the file:
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/module.py
# Git commit hash: fafc827d643b3eed611e282d909025f16be36601
# We retain the following license from the original files:
# MIT License

# Copyright (c) Microsoft Corporation.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
import re
import glob
Expand All @@ -24,6 +50,7 @@
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting
from paddle.fluid.framework import in_dygraph_mode

__all__ = []

Expand Down Expand Up @@ -269,15 +296,20 @@ def allreduce_shared_weight_gradients(self):
for key, comm in self.shared_comm.items():
param = getattr(self.shared_layers[key], comm['weight_attr'])
# need use trace_op to allreduce weight
with paddle.framework.no_grad():
paddle.fluid.framework._dygraph_tracer().trace_op(
type="c_allreduce_sum",
inputs={'X': param._grad_ivar()},
outputs={'Out': param._grad_ivar()},
attrs={
'ring_id': comm['group'].id,
'use_calc_stream': True
})
if in_dygraph_mode():
with paddle.framework.no_grad():
paddle.distributed.all_reduce(
param.grad, group=comm['group'])
else:
with paddle.framework.no_grad():
paddle.fluid.framework._dygraph_tracer().trace_op(
type="c_allreduce_sum",
inputs={'X': param._grad_ivar()},
outputs={'Out': param._grad_ivar()},
attrs={
'ring_id': comm['group'].id,
'use_calc_stream': True
})

def _segment_network(self, seg_method):
logger.info("start segment network..")
Expand Down
57 changes: 30 additions & 27 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

__all__ = []

Expand Down Expand Up @@ -238,9 +239,9 @@ def _forward_step(self, input_tensor):
assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
labels = self._load_micro_batch(self.micro_batch_id)
output_tensor = self._layers._loss_fn(output_tensor, labels)
assert isinstance(
output_tensor, paddle.Tensor
), "Currently, loss_fn should obtain Paddle.Tensor dtype"
assert isinstance(output_tensor, (
paddle.Tensor, core.eager.Tensor
)), "Currently, loss_fn should obtain Paddle.Tensor dtype"

with paddle.amp.auto_cast(enable=False):
if self.accumulate_steps > 1:
Expand All @@ -254,31 +255,33 @@ def _forward_step(self, input_tensor):
return output_tensor

def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self.is_last_stage:
assert output_tensor_grad is None
if self.scaler:
paddle.autograd.backward(self.scaler.scale(output_tensor))
else:
paddle.autograd.backward(output_tensor)
else:
if isinstance(output_tensor, tuple):
outputs = [t for t in output_tensor if not t.stop_gradient]
assert len(outputs) == len(output_tensor_grad)
paddle.autograd.backward(
tensors=outputs,
grad_tensors=[t for t in output_tensor_grad])
else:
paddle.autograd.backward(
tensors=[output_tensor], grad_tensors=[output_tensor_grad])

input_tensor_grad = None
if input_tensor is not None:
if isinstance(input_tensor, tuple):
input_tensor_grad = tuple(
[t.grad for t in input_tensor if not t.stop_gradient])
with paddle.amp.auto_cast(enable=False):
if self.is_last_stage:
assert output_tensor_grad is None
if self.scaler:
paddle.autograd.backward(self.scaler.scale(output_tensor))
else:
paddle.autograd.backward(output_tensor)
else:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
if isinstance(output_tensor, tuple):
outputs = [t for t in output_tensor if not t.stop_gradient]
assert len(outputs) == len(output_tensor_grad)
paddle.autograd.backward(
tensors=outputs,
grad_tensors=[t for t in output_tensor_grad])
else:
paddle.autograd.backward(
tensors=[output_tensor],
grad_tensors=[output_tensor_grad])

input_tensor_grad = None
if input_tensor is not None:
if isinstance(input_tensor, tuple):
input_tensor_grad = tuple(
[t.grad for t in input_tensor if not t.stop_gradient])
else:
input_tensor_grad = input_tensor.grad
return input_tensor_grad

def _load_micro_batch(self, cache_id):
inputs = self.data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ...utils.log_util import logger
import numpy as np
from paddle import _C_ops
import paddle.fluid.core as core

_hcg = None
_use_cache = False
Expand Down Expand Up @@ -114,7 +115,7 @@ def _send_dims_shape_dtype(self, tensor, group):
paddle.distributed.send(stop_grad, dst=1, group=group)

def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor):
if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
tensor_type = paddle.to_tensor([0])
# send tensor type
paddle.distributed.send(tensor_type, dst=1, group=group)
Expand All @@ -129,11 +130,11 @@ def send_meta(self, tensor, group):
paddle.distributed.send(nums, dst=1, group=group)

for d in tensor:
assert isinstance(d, paddle.Tensor)
assert isinstance(d, (paddle.Tensor, core.eager.Tensor))
self._send_dims_shape_dtype(d, group=group)

def set_send_message(self, tensor):
if isinstance(tensor, paddle.Tensor):
if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)):
self.send_shape_message = tensor.shape
self.send_dtype_message = paddle_2_number(tensor.dtype)
elif isinstance(tensor, tuple):
Expand Down
144 changes: 140 additions & 4 deletions python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import paddle
from paddle.fluid import core
from paddle import _C_ops
from paddle.autograd import PyLayer
from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable
from ..parallel_layers.random import get_rng_state_tracker
from paddle.fluid.framework import in_dygraph_mode

__all__ = []

Expand Down Expand Up @@ -164,6 +165,138 @@ def _swith_rng_state_tracker(rng_state, tracker):
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)


class _HPEagerRecomputeFunction(EagerPyLayer):
    """
    Eager-mode recompute layer for hybrid-parallel training.

    Differences from paddle.distributed.fleet.utils.recompute:
    1. The recompute input may be a tuple, which PipeLineParallel requires.
    2. Saved activations can be offloaded (copied with ``.cpu()``).
    3. Saved activations can be segmented across MP to further reduce
       cuda memory.
    4. The MP random state is captured in forward and replayed in backward.
    """

    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        """
        Run ``run_function(*args)`` without recording gradients and stash
        everything backward needs to replay it.

        Args:
            ctx: PyLayer context; receives the function, RNG states, AMP
                settings and the saved inputs.
            run_function: callable that is executed now and again in backward.
            all_outputs: list that is extended in place with the outputs.
            *args: inputs forwarded to ``run_function``.

        Returns:
            The output tensor itself, or ``tuple(outputs)`` when
            ``run_function`` returned a non-tensor collection.
        """
        check_recompute_necessary(args)

        # Remember what has to be re-executed during backward.
        ctx.run_function = run_function

        # Snapshot both RNG states so the replay sees the same randomness.
        ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
        ctx.fwd_cuda_rng_state_tracker = (
            get_rng_state_tracker().get_states_tracker())

        # Containers describing the inputs that backward will rebuild.
        ctx.inputs = []
        ctx.tensor_indices = []
        ctx.tensor_shapes = []
        saved_tensors = []

        cur_device = paddle.get_device()
        assert 'gpu:' in cur_device, (
            "Recompute with RNG is not support current device: {}.".format(
                cur_device))

        # TODO support AMP
        tracer = framework._dygraph_tracer()
        amp_level = tracer._amp_level
        ctx.is_fw_autocast = amp_level != core.AmpLevel.O0
        if amp_level == core.AmpLevel.O2:
            ctx.amp_level = 'O2'
        elif amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            ctx.amp_level = 'O1'
        else:
            raise ValueError("unsupported amp level: {}".format(amp_level))
        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
            outputs = run_function(*args)

        for position, item in enumerate(args):
            if not paddle.is_tensor(item):
                # Non-tensor arguments are kept as they are.
                ctx.inputs.append(item)
                continue

            keep_stop_gradient = item.stop_gradient
            if _recompute_partition:
                # The full shape is recorded so backward can reshape the
                # merged activation back.
                ctx.tensor_shapes.append(item.shape)
                piece = _split_activation(item.detach()).clone()
                # TODO(shenliang03) not use calculate stream to D2H to speed
                item = piece.cpu() if _recompute_offload else piece
            elif _recompute_offload:
                item = item.cpu()
            item.stop_gradient = keep_stop_gradient
            saved_tensors.append(item)
            ctx.tensor_indices.append(position)
            # Placeholder slot; backward fills it from the saved tensors.
            ctx.inputs.append(None)

        ctx.save_for_backward(*saved_tensors)

        if paddle.is_tensor(outputs):
            all_outputs.append(outputs)
            return outputs

        all_outputs.extend(outputs)
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *args):
        """
        Re-run the saved function on detached inputs and back-propagate
        the incoming gradients through the recomputed outputs.

        Args:
            ctx: PyLayer context filled in by ``forward``.
            *args: gradients, one per forward output.

        Returns:
            Tuple with ``_grad_ivar()`` of every detached tensor input.

        Raises:
            RuntimeError: if no recomputed output has ``stop_gradient=False``.
        """
        with paddle.fluid.dygraph.guard():
            # Rebuild the original argument list.
            restored = list(ctx.inputs)
            indices = ctx.tensor_indices
            shapes = ctx.tensor_shapes
            saved = list(ctx.saved_tensor())

            device_id = paddle.distributed.ParallelEnv().device_id
            for slot, arg_index in enumerate(indices):
                tensor = saved[slot]
                if _recompute_partition:
                    keep_stop_gradient = tensor.stop_gradient
                    tensor = _merge_activation(tensor).detach().reshape_(
                        shapes[slot])
                    tensor.stop_gradient = keep_stop_gradient
                if _recompute_offload:
                    tensor = tensor.cuda(device_id)
                restored[arg_index] = tensor

            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # Replay under the forward RNG state and the forward auto_cast
            # configuration, including the white/black op lists.
            with _swith_rng_state_tracker(
                    ctx.fwd_cuda_rng_state,
                    ctx.fwd_cuda_rng_state_tracker), paddle.amp.auto_cast(
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level):
                detached_inputs = detach_variable(tuple(restored))
                outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.eager.Tensor):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # Keep only outputs that take part in autograd, paired with
            # their incoming gradient.
            outputs_with_grad = []
            incoming_grads = []
            for out, grad in zip(outputs, args):
                if isinstance(out,
                              core.eager.Tensor) and not out.stop_gradient:
                    outputs_with_grad.append(out)
                    incoming_grads.append(grad)

            if not outputs_with_grad:
                raise RuntimeError(
                    "none of output has stop_gradient=False, this recompute() is not necessary"
                )

            # actually backward
            paddle.autograd.backward(outputs_with_grad, incoming_grads)
            return tuple(inp._grad_ivar() for inp in detached_inputs
                         if isinstance(inp, core.eager.Tensor))


class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
Expand Down Expand Up @@ -290,8 +423,8 @@ def backward(ctx, *args):

# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
return grads


Expand All @@ -303,7 +436,10 @@ def _hp_recompute(function, *args):
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor

all_outputs = []
_HPRecomputeFunction.apply(function, all_outputs, *args)
if in_dygraph_mode():
_HPEagerRecomputeFunction.apply(function, all_outputs, *args)
else:
_HPRecomputeFunction.apply(function, all_outputs, *args)

if len(all_outputs) == 1:
return all_outputs[0]
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 350)
set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 300)
set_tests_properties(test_parallel_dygraph_no_sync_gradient_check PROPERTIES TIMEOUT 30)
set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 500)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
Expand Down
9 changes: 5 additions & 4 deletions python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ def test_pp_model(self):

with paddle.amp.auto_cast():
loss_a = model_a(img, label)
scaler_a.scale(loss_a).backward()
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()

scaler_a.scale(loss_a).backward()
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()

with paddle.amp.auto_cast():
loss_b = model_b.train_batch(
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def test_pp_model(self):

with paddle.amp.auto_cast(enable=True, level='O2'):
loss_a = model_a(img, label)
scaler_a.scale(loss_a).backward()
with paddle.amp.auto_cast(enable=False):
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()
scaler_a.scale(loss_a).backward()
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()

with paddle.amp.auto_cast(enable=True, level='O2'):
loss_b = model_b.train_batch(
[img, label], optimizer_b, scheduler_b, scaler=scaler_b)

Expand Down
Loading

0 comments on commit 7eae657

Please sign in to comment.