[Dy2St]Get grad names when call append backward to fix high order gradient (PaddlePaddle#53250)
0x45f authored and 0x45f committed Apr 27, 2023
1 parent b699659 commit 723ceef
Showing 10 changed files with 262 additions and 111 deletions.
19 changes: 15 additions & 4 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -689,15 +689,26 @@ class GradNodeRunProgram : public egr::GradNodeBase {
protected:
void ConstructXGradTensors(const std::vector<paddle::Tensor> &x,
std::vector<paddle::Tensor> *x_grad) {
auto x_grad_names =
PADDLE_GET_CONST(std::vector<std::string>, attrs_.at("x_grad_names"));
PADDLE_ENFORCE_EQ(
x.size(),
x_grad_names.size(),
paddle::platform::errors::InvalidArgument(
"The x.size() and x_grad_names.size() should be equal. "
"But received x.size() = %d, x_grad_names.size() = %d",
x.size(),
x_grad_names.size()));

// TODO(dev): Need an elegant way to determine the information of grad_tensor,
// such as: name, tensor type (DenseTensor or SelectedRows).
for (auto &t : x) {
if (t.is_dense_tensor()) {
for (size_t i = 0; i < x.size(); i++) {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (t.is_selected_rows()) {
} else if (x[i].is_selected_rows()) {
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(t.name() + "@GRAD");
x_grad->back().set_name(x_grad_names[i]);
}
}

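The enforce above encodes a simple contract: the Python side must pass exactly one gradient name per input, in input order. A minimal sketch of that pairing in plain Python (the helper name is illustrative, not part of the commit; the real list is assembled in partial_program.py further down):

# Minimal sketch of the contract checked by PADDLE_ENFORCE_EQ above:
# x_grad_names[i] must be the gradient name for x[i].
def pair_inputs_with_grad_names(x_names, x_grad_names):
    if len(x_names) != len(x_grad_names):
        raise ValueError(
            f"x has {len(x_names)} entries but x_grad_names has "
            f"{len(x_grad_names)}"
        )
    return dict(zip(x_names, x_grad_names))

# e.g. pair_inputs_with_grad_names(['x', 'y'], ['x@GRAD', 'y@GRAD'])
# returns {'x': 'x@GRAD', 'y': 'y@GRAD'}
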
4 changes: 4 additions & 0 deletions paddle/fluid/operators/run_program_op.cc
@@ -139,6 +139,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"std::vector<std::string>"
"The names of output gradients.")
.SetDefault({});
AddAttr<std::vector<std::string>>("x_grad_names",
"std::vector<std::string>"
"The names of input gradients.")
.SetDefault({});
AddComment(R"DOC(
RunProgram operator.
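
Callers pass the new attribute through run_program's flat (key, value, key, value, ...) attribute list. A hedged sketch of that pattern, mirroring the unit tests further down (the helper name is illustrative, not part of this commit):

# Hypothetical helper showing how the flat attrs list gains the new entry.
def append_x_grad_names(attrs, input_vars):
    attrs.extend(
        (
            'x_grad_names',
            [v.name + '@GRAD' for v in input_vars],
        )
    )
    return attrs
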
69 changes: 47 additions & 22 deletions python/paddle/fluid/backward.py
@@ -2374,28 +2374,12 @@ def _find_op_path_(
return op_path


def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
"""
Backpropagate the gradients of targets to inputs.
Args:
targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors
inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors
target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors
of targets, which have the same shape as targets. If None, ones will
be created for them.
no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All Tensors with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
Default: None.
Return:
(list[Tensor]): A list of gradients for inputs
If an input does not affect targets, the corresponding gradient Tensor
will be None
"""
def calc_gradient_helper(
targets, inputs, target_gradients=None, no_grad_set=None
):
'''
Calculate gradient and return grad_info_map
'''
targets = _as_list(targets)
inputs = _as_list(inputs)
target_gradients = _as_list(target_gradients)
@@ -2508,7 +2492,11 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):

_append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
prog._sync_with_cpp()
return grad_info_map


def _get_grad_vars(grad_info_map, inputs):
inputs = _as_list(inputs)
grad_vars = []
for input_var in inputs:
if input_var.name not in grad_info_map:
@@ -2518,6 +2506,43 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
grad_block = grad_info[1]
grad_var = grad_block.var(grad_info[0])
grad_vars.append(grad_var)
return grad_vars


def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
"""
Backpropagate the gradients of targets to inputs.
Args:
targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors
inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors
target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors
of targets, which have the same shape as targets. If None, ones will
be created for them.
no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All Tensors with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
Default: None.
Return:
(list[Tensor]): A list of gradients for inputs
If an input does not affect targets, the corresponding gradient Tensor
will be None
"""

# NOTE: If you want to modify the logic of calc_gradient, please modify
# it inside the calc_gradient_helper and _get_grad_vars functions
# to ensure the correctness of dy2st mode.
grad_info_map = calc_gradient_helper(
targets,
inputs,
target_gradients=target_gradients,
no_grad_set=no_grad_set,
)

grad_vars = _get_grad_vars(grad_info_map, inputs)

if len(grad_vars) == 1:
return grad_vars[0]
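
A hedged, minimal example of how the split pieces compose in static-graph mode (module paths as of this commit; the tiny network and names are illustrative, not from the commit):

import paddle
from paddle.fluid import backward

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    x.stop_gradient = False
    loss = paddle.mean(paddle.static.nn.fc(x, size=2))

    # Step 1: append the backward ops and collect grad_info_map.
    grad_info_map = backward.calc_gradient_helper(targets=[loss], inputs=[x])
    # Step 2: resolve the gradient variables for the requested inputs.
    (x_grad,) = backward._get_grad_vars(grad_info_map, [x])
    print(x_grad.name)  # typically 'x@GRAD'
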
9 changes: 6 additions & 3 deletions python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -83,7 +83,8 @@ def test_check_output(self):
self.check_output(check_prim=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
# Now in dy2st mode x_grad = [], so set check_prim=False
self.check_grad(['X'], 'Out', check_prim=False)


class TestDropoutOpInput1d(OpTest):
@@ -107,7 +108,8 @@ def test_check_output(self):
self.check_output(check_prim=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
# Now in dy2st mode x_grad = [], so set check_prim=False
self.check_grad(['X'], 'Out', check_prim=False)


class TestDropoutOp2(TestDropoutOp):
@@ -283,7 +285,8 @@ def test_check_output(self):
self.check_output(check_prim=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.05, check_prim=True)
# Now in dy2st mode x_grad = [], so set check_prim=False
self.check_grad(['X'], 'Out', max_relative_error=0.05, check_prim=False)


@unittest.skipIf(
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_eager_run_program.py
@@ -134,6 +134,8 @@ def test_eager(self):
['Fake_var@GRAD'],
'out_grad_names',
[out.name + '@GRAD'],
'x_grad_names',
[x_t.name + '@GRAD', y_t.name + '@GRAD'],
]

use_interpretorcore = True
4 changes: 4 additions & 0 deletions python/paddle/fluid/tests/unittests/test_run_program_op.py
@@ -254,6 +254,8 @@ def calc_dygraph_output(self, place):
[p.name + '@GRAD' for p in inputs['Params']],
'out_grad_names',
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
)
)

@@ -303,6 +305,8 @@ def calc_dygraph_grad(self, place):
[p.name + '@GRAD' for p in inputs['Params']],
'out_grad_names',
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
)
)

58 changes: 37 additions & 21 deletions python/paddle/jit/dy2static/partial_program.py
@@ -21,17 +21,16 @@
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from paddle.optimizer.lr import LRScheduler

from . import logging_utils
from .utils import (
RETURN_NO_VALUE_MAGIC_NUM,
_out_grad_names,
_param_grad_names,
backend_guard,
construct_grad_names,
)

__all__ = []
@@ -208,6 +207,7 @@ def __init__(
self._scope_cache = {}
self._hooker = None
self._backend = kwargs.get('backend', None)
self._grad_var_names = {}

def __call__(self, inputs):
"""
@@ -443,23 +443,11 @@ def _train_pure_fp16_program_id(self):
def _infer_pure_fp16_program_id(self):
return paddle.utils._hash_with_id(self._infer_pure_fp16_program, self)

@LazyInitialized
def _param_grad_names(self):
return _param_grad_names(self._train_program.desc, self._params)

def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[
paddle.utils._hash_with_id(program, self)
]

@LazyInitialized
def _out_grad_names(self):
return _out_grad_names(
self._train_program.desc,
self.get_forward_end_op_idx(self._train_program),
len(self._outputs.var_ids),
)

@property
def program(self):
"""
@@ -649,7 +637,33 @@ def _append_backward_desc(self, main_program):
if targets:
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
with backend_guard(self._backend):
backward.gradients(targets=targets, inputs=[])
check_type(
targets,
'targets',
(framework.Variable, list, tuple),
'paddle.static.gradients',
)
grad_info_map = backward.calc_gradient_helper(
targets=targets, inputs=[]
)

x_vars = [
program.block(0).var(var.name)
for var in self._inputs
if isinstance(var, framework.Variable)
]
param_vars = [
program.block(0).var(param.name) for param in self._params
]
out_vars = [
program.block(0).var(var.name)
for var in self._outputs
if isinstance(var, framework.Variable)
]

self._grad_var_names = construct_grad_names(
grad_info_map, x_vars, param_vars, out_vars
)

if self._hooker:
program, start_idx = self._hooker.after_append_backward(
@@ -720,9 +734,11 @@ def _prepare_attributes(self):
attrs.extend(
(
'param_grad_names',
self._param_grad_names,
self._grad_var_names.get('param', []),
'out_grad_names',
self._out_grad_names,
self._grad_var_names.get('out', []),
'x_grad_names',
self._grad_var_names.get('x', []),
)
)
if self._cuda_graph_capture_mode:
@@ -761,9 +777,9 @@ def _get_forward_backward_program_form(
backward_end_op_index = whole_program.desc.block(0).op_size()
# For the backward process in CINN, all param@GRAD should be skipped for GC, because
# they will be shared in the scope and used by the optimizer.
backward_skip_vars = (
self._parse_skip_gc_vars(whole_program) + self._param_grad_names
)
backward_skip_vars = self._parse_skip_gc_vars(
whole_program
) + self._grad_var_names.get('param', [])
backward_builded_program = add_build_strategy_for(
whole_program,
backward_start_op_index,
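
Condensed, for orientation: _append_backward_desc now derives all three grad-name lists from a single grad_info_map, and _prepare_attributes forwards them to the op. A hedged sketch of just that plumbing (a standalone toy class, not the real PartialProgramLayer; import paths as of this commit):

from paddle.fluid import backward
from paddle.jit.dy2static.utils import construct_grad_names


class GradNamePlumbing:
    """Toy stand-in showing only the grad-name flow added by this commit."""

    def __init__(self):
        self._grad_var_names = {}

    def append_backward(self, targets, x_vars, param_vars, out_vars):
        # Build the backward ops once, then reuse grad_info_map for all
        # three name lists instead of re-parsing the program desc.
        grad_info_map = backward.calc_gradient_helper(targets=targets, inputs=[])
        self._grad_var_names = construct_grad_names(
            grad_info_map, x_vars, param_vars, out_vars
        )

    def prepare_attributes(self, attrs):
        # The run_program op reads these three attributes by name.
        attrs.extend(
            (
                'param_grad_names', self._grad_var_names.get('param', []),
                'out_grad_names', self._grad_var_names.get('out', []),
                'x_grad_names', self._grad_var_names.get('x', []),
            )
        )
        return attrs
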
73 changes: 17 additions & 56 deletions python/paddle/jit/dy2static/utils.py
@@ -32,7 +32,7 @@

import paddle
from paddle import fluid # noqa: F401
from paddle.fluid import core, unique_name
from paddle.fluid import backward, core, framework, unique_name
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
@@ -1457,61 +1457,6 @@ def create_name_str(name_ids):
return "(%s, )" % ','.join(names_str)


def _param_grad_names(program_desc, params):
"""
Parse PARAM@GRAD names from the original train and infer programs.
"""
names = []
# NOTE: `names` and `params` must be in the same order so that
# the param grad name can be set correctly in the run_program.
for param in params:
candidate = []
for var in program_desc.block(0).all_vars():
var_name = var.name()
if param.name not in var_name:
continue
suf_count = var_name.count(GRAD_SUFFIX)
if suf_count > 0:
suffix = param.name + GRAD_SUFFIX * suf_count
pre_count = var_name.count(GRAD_PREFIX)
if GRAD_PREFIX * pre_count + suffix == var_name:
candidate.append(var_name)

if candidate:
names.append(
max(
candidate,
key=lambda name: name.count(GRAD_PREFIX)
if GRAD_PREFIX in name
else name.count(GRAD_SUFFIX),
)
)
else:
names.append(param.name + GRAD_SUFFIX)
return names


def _out_grad_names(program_desc, fwd_end_op_index, out_size):
"""
Parse Out@GRAD names from the original train and infer programs.
"""
names = []
for i in range(
fwd_end_op_index,
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
# If prim forward is enabled, fill_any_like will be decomposed into fill_constant.
if core._is_fwd_prim_enabled():
target = ('fill_any_like', 'fill_constant')
else:
target = 'fill_any_like'
if op.type() in target:
var_name = op.output('Out')[0]
names.append(var_name)
return names


def prim_or_cinn_is_enabled(build_strategy, backend):
if backend == 'CINN':
return True
@@ -1566,3 +1511,19 @@ def backend_guard(backend):
finally:
core._set_prim_forward_enabled(orign_fwd)
core._set_prim_backward_enabled(orign_bwd)


def construct_grad_names(grad_info_map, x_vars, param_vars, out_vars):
grad_var_names = {}
fn = (
lambda grad_var: grad_var.name
if isinstance(grad_var, framework.Variable)
else framework.EMPTY_VAR_NAME
)
x_grad_vars = backward._get_grad_vars(grad_info_map, x_vars)
grad_var_names['x'] = list(map(fn, x_grad_vars))
param_grad_vars = backward._get_grad_vars(grad_info_map, param_vars)
grad_var_names['param'] = list(map(fn, param_grad_vars))
out_grad_vars = backward._get_grad_vars(grad_info_map, out_vars)
grad_var_names['out'] = list(map(fn, out_grad_vars))
return grad_var_names
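
A hedged end-to-end example of construct_grad_names on a tiny static program (import paths as of this commit; shapes and names are illustrative, not from the commit):

import paddle
from paddle.fluid import backward
from paddle.jit.dy2static.utils import construct_grad_names

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    x.stop_gradient = False
    w = paddle.create_parameter(shape=[8, 2], dtype='float32')
    out = paddle.matmul(x, w)
    loss = paddle.mean(out)

    grad_info_map = backward.calc_gradient_helper(targets=[loss], inputs=[])
    names = construct_grad_names(grad_info_map, [x], [w], [out])
    # names['x'], names['param'], names['out'] hold the '@GRAD' names;
    # any variable without a gradient maps to framework.EMPTY_VAR_NAME.
    print(names)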