forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[static code gen] support composite grad maker code gen #37
Merged
JiabinYang
merged 23 commits into
JiabinYang:prim_paddle
from
Charles-hit:static_composite_gen
Jan 8, 2023
Merged
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
0863e6b
support static graph code-gen for squeeze op
zyfncg b54bd12
generate static graph code of unsqueeze
zyfncg bb4d8ec
refine op name
zyfncg 6cd4f35
add extra output in op_compat
zyfncg bf1a2b1
remove debug log
zyfncg 9ed243f
add composite parse
Charles-hit ca65aad
Merge commit 'refs/pull/49430/head' of /~https://github.com/PaddlePaddl…
Charles-hit bf9f920
support generate static graph code for imag and real op
zyfncg 504d84b
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
zyfncg 5b22e09
Merge commit 'refs/pull/49523/head' of /~https://github.com/PaddlePaddl…
Charles-hit 043bd4b
Merge branch 'prim_paddle' of /~https://github.com/JiabinYang/Paddle in…
Charles-hit a40e5d8
Merge branch 'prim_paddle' of /~https://github.com/JiabinYang/Paddle in…
Charles-hit 44c902c
add composite code gen
Charles-hit 09e0192
modify backward yaml
Charles-hit c97ebba
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
Charles-hit 3ef725d
Merge branch 'prim_paddle' of /~https://github.com/JiabinYang/Paddle in…
Charles-hit 8ae8f0e
fix static composite grad maker code gen
Charles-hit 742c18e
add some static funcs unit test
Charles-hit e5f11f5
Merge branch 'prim_paddle' of /~https://github.com/JiabinYang/Paddle in…
Charles-hit bd7eded
fix some bugs
Charles-hit 0f2963d
Merge branch 'prim_paddle' of /~https://github.com/JiabinYang/Paddle in…
Charles-hit c2ff1a7
fix composite grad maker register code gen
Charles-hit caec995
optimize some functions
Charles-hit File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
import yaml | ||
from filters import ( | ||
cartesian_prod_mapping, | ||
to_composite_grad_opmaker_name, | ||
to_input_name, | ||
to_int_array_tensor_name, | ||
to_int_array_tensors_name, | ||
|
@@ -32,6 +33,7 @@ | |
from parse_utils import to_named_dict | ||
from tests import ( | ||
is_base_op, | ||
is_composite_op, | ||
is_initializer_list, | ||
is_scalar, | ||
is_vec, | ||
|
@@ -57,7 +59,9 @@ | |
env.filters["to_input_name"] = to_input_name | ||
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr | ||
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping | ||
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name | ||
env.tests["base_op"] = is_base_op | ||
env.tests["composite_op"] = is_composite_op | ||
env.tests["vec"] = is_vec | ||
env.tests["scalar"] = is_scalar | ||
env.tests["initializer_list"] = is_initializer_list | ||
|
@@ -153,6 +157,27 @@ def process_int_array(op_item, int_array_configs): | |
] | ||
|
||
|
||
def parse_composite_info(ops, backward_ops, backward_op_dict): | ||
for op in ops: | ||
if "backward" in op: | ||
op["phi_backward"] = op["backward"] | ||
for backward_op in backward_ops: | ||
if "backward" in backward_op: | ||
backward_op["phi_backward"] = backward_op["backward"] | ||
for backward_op_name, op_dict in backward_op_dict.items(): | ||
if "composite" not in op_dict: | ||
continue | ||
op_dict["composite"]["phi_inputs"] = [] | ||
op_dict["composite"]["phi_attrs"] = [] | ||
op_dict["composite"]["phi_outputs"] = [] | ||
for input in op_dict["inputs"]: | ||
op_dict["composite"]["phi_inputs"].append(input['name']) | ||
for attr in op_dict["attrs"]: | ||
op_dict["composite"]["phi_attrs"].append(attr['name']) | ||
for output in op_dict["outputs"]: | ||
op_dict["composite"]["phi_outputs"].append(output['name']) | ||
|
||
|
||
# replace name of op and params for OpMaker | ||
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): | ||
def get_phi_and_fluid_op_name(op_item): | ||
|
@@ -178,6 +203,37 @@ def update_grad_args_name(op_args, args_alias_map): | |
) | ||
item['name'] = args_alias_map[item['name'][:-5]] + '_grad' | ||
|
||
def add_fluid_info_in_composite(composite_map, args_alias_map): | ||
fluid_input_list = [] | ||
fluid_attr_list = [] | ||
fluid_output_list = [] | ||
# add fluid op inputs | ||
for input in composite_map["phi_inputs"]: | ||
if input in args_alias_map: | ||
fluid_input_list.append(args_alias_map[input]) | ||
else: | ||
fluid_input_list.append(input) | ||
# add fluid op attrs | ||
for attr in composite_map["phi_attrs"]: | ||
if attr in args_alias_map: | ||
fluid_attr_list.append(args_alias_map[attr]) | ||
else: | ||
fluid_attr_list.append(attr) | ||
# add fluid op outputs | ||
for output in composite_map["phi_outputs"]: | ||
if output in args_alias_map: | ||
fluid_output_list.append(args_alias_map[output]) | ||
else: | ||
fluid_output_list.append(output) | ||
|
||
composite_map.update( | ||
{ | ||
"fluid_inputs": fluid_input_list, | ||
"fluid_attrs": fluid_attr_list, | ||
"fluid_outputs": fluid_output_list, | ||
} | ||
) | ||
|
||
def get_param_list_alias(param_list, args_map): | ||
return [ | ||
args_map[param] if param in args_map else param | ||
|
@@ -307,6 +363,15 @@ def update_grad_op_compat_name(grad_op_item, args_name_map): | |
continue | ||
|
||
backward_op_list = op_args['backward'].split(',') | ||
# add fluid args name in composite map | ||
for backward_op in backward_op_list: | ||
if ( | ||
"composite" | ||
in backward_op_dict[backward_op.split('(')[0].strip()] | ||
): | ||
add_fluid_info_in_composite( | ||
backward_op_dict[backward_op]["composite"], args_map | ||
) | ||
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0]) | ||
forward_op_item['backward'] = bw_op_name | ||
backward_op_item['op_name'] = bw_op_name | ||
|
@@ -406,12 +471,10 @@ def main( | |
ops = yaml.safe_load(f) | ||
ops = [restruct_io(op) for op in ops] | ||
forward_op_dict = to_named_dict(ops) | ||
|
||
with open(backward_yaml_path, "rt") as f: | ||
backward_ops = yaml.safe_load(f) | ||
backward_ops = [restruct_io(op) for op in backward_ops] | ||
backward_op_dict = to_named_dict(backward_ops) | ||
|
||
with open(op_version_yaml_path, "rt") as f: | ||
op_versions = yaml.safe_load(f) | ||
# add op version info into op | ||
|
@@ -426,6 +489,8 @@ def main( | |
for bw_op in backward_ops: | ||
bw_op['op_name'] = bw_op['name'] | ||
|
||
parse_composite_info(ops, backward_ops, backward_op_dict) | ||
|
||
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict) | ||
|
||
# prepare for invoke case | ||
|
@@ -442,21 +507,21 @@ def main( | |
op_dict = {} | ||
op_dict.update(forward_op_dict) | ||
op_dict.update(backward_op_dict) | ||
|
||
if len(ops) == 0 and len(backward_ops) == 0: | ||
if os.path.isfile(output_op_path): | ||
os.remove(output_op_path) | ||
if os.path.isfile(output_arg_map_path): | ||
os.remove(output_arg_map_path) | ||
return | ||
|
||
op_template = env.get_template('op.c.j2') | ||
with open(output_op_path, "wt") as f: | ||
msg = op_template.render( | ||
ops=ops, backward_ops=backward_ops, op_dict=op_dict | ||
ops=ops, | ||
backward_ops=backward_ops, | ||
op_dict=op_dict, | ||
composite_gen_flag=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add TODO here to support all static code gen with composite |
||
) | ||
f.write(msg) | ||
|
||
ks_template = env.get_template('ks.c.j2') | ||
with open(output_arg_map_path, 'wt') as f: | ||
msg = ks_template.render(ops=ops, backward_ops=backward_ops) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this? This is duplicated with line 159