Skip to content

Commit

Permalink
Silu double grad (#53605)
Browse files Browse the repository at this point in the history
* add rules

* modify no kernel yaml parse

* success op generate

* success test_silu_double

* modify bug

* modify static error

* modify silu_grad input

* modify kernel signature

* modify kernel signature

* code style

* code style

* review

* delete opinfo modify
  • Loading branch information
xiaoguoguo626807 authored May 15, 2023
1 parent 0ef5180 commit 94c3880
Show file tree
Hide file tree
Showing 14 changed files with 529 additions and 357 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"matmul_double_grad",
"tanh_double_grad",
"subtract_double_grad",
"silu_double_grad",
]

# dict of special api that forward api's output will affect bacward api's output
Expand Down
15 changes: 15 additions & 0 deletions paddle/fluid/operators/generator/generate_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_base_op,
is_composite_op,
is_initializer_list,
is_only_composite_op,
is_scalar,
is_vec,
supports_inplace,
Expand Down Expand Up @@ -72,6 +73,7 @@
env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["only_composite_op"] = is_only_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
Expand Down Expand Up @@ -165,6 +167,16 @@ def add_composite_info(ops, backward_ops, backward_op_dict):
else:
op["backward_composite"] = None

# add whether only composite
if (
op["backward_composite"] is not None
and "invoke" not in backward_op_dict[op["backward"]]
and "kernel" not in backward_op_dict[op["backward"]]
):
op["only_backward_composite"] = True
else:
op["only_backward_composite"] = False


# add fluid name in ops and backward ops info
def add_fluid_name(dict_list):
Expand Down Expand Up @@ -248,6 +260,9 @@ def update_common_params_name(
for param in op_item['invoke']['args'].split(',')
]
return
elif 'composite' in op_item and 'kernel' not in op_item:
return

op_item['infer_meta']['param'] = get_param_list_alias(
op_item['infer_meta']['param'], args_name_map
)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/generator/generate_sparse_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_base_op,
is_composite_op,
is_initializer_list,
is_only_composite_op,
is_scalar,
is_vec,
supports_inplace,
Expand Down Expand Up @@ -71,6 +72,7 @@
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["only_composite_op"] = is_only_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
Expand Down
17 changes: 11 additions & 6 deletions paddle/fluid/operators/generator/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,12 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
"data_transform": data_trans,
}

# invokes another op ?
is_base_op = "invoke" not in op_entry
# op should be is_base_op or is_invoke_op or is_only_composite_op
is_base_op = True
if "invoke" in op_entry:
is_base_op = False
if "composite" in op_entry and "kernel" not in op_entry:
is_base_op = False

if is_base_op:
# kernel
Expand All @@ -524,10 +528,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
"inplace": inplace_pairs,
}
)
else:
# invoke
invoke = parse_invoke(op_name, op_entry["invoke"])
op["invoke"] = invoke

# has invoke ?
if "invoke" in op_entry:
invoke_dict = parse_invoke(op_name, op_entry["invoke"])
op.update({"invoke": invoke_dict})

# has composite ?
if "composite" in op_entry:
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/generator/templates/op.c.j2
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ using paddle::framework::GradVarName;
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}

{{operator(op)}}
{% elif op is only_composite_op %}

{% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker,
{% endif %}
{% if "backward" in op and op["backward"] is not none %}{# backward #}
{% if "only_backward_composite" in op and op["only_backward_composite"] is true %}{# backward #}
{% elif "backward" in op and op["backward"] is not none %}
{% set backward_name = op["backward"] %}
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/generator/tests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def is_base_op(op):
return "kernel" in op and "infer_meta" in op


def is_only_composite_op(op):
return "composite" in op and "kernel" not in op and "invoke" not in op


def is_composite_op(op):
return "composite" in op

Expand Down
Loading

0 comments on commit 94c3880

Please sign in to comment.