Adjusted python-level trace_op to accommodate final state Eager Dygraph #39319

Merged 27 commits on Feb 14, 2022

Commits
126871b
Removed debug info
jim19930609 Jan 14, 2022
04b9bfb
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Jan 14, 2022
211a703
Added automatic code generation for final state Eager Dygraph
jim19930609 Jan 24, 2022
c2d832d
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Jan 24, 2022
c292c9f
Modified backward yaml
jim19930609 Jan 24, 2022
1d75522
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Jan 24, 2022
3723cab
Added EagerUtils helper functions for final state CodeGen
jim19930609 Jan 25, 2022
ca74350
Adjusted CMakeFiles to support compilation for final state auto gener…
jim19930609 Jan 25, 2022
62b1556
Added python-c code generation for final state Eager Dygraph
jim19930609 Jan 26, 2022
70f293d
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Jan 26, 2022
64e2421
Fixed minor issue
jim19930609 Jan 26, 2022
033482d
Fixed yaml.load() method failure
jim19930609 Jan 27, 2022
3b839b2
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Jan 28, 2022
fb7fcf6
Fixed minor issues
jim19930609 Jan 29, 2022
7b5eab7
Refactored Python-C Attributes Parsing Functions
jim19930609 Jan 29, 2022
e900a1e
Merge branch 'develop' into refactor_python_c_attrs
jim19930609 Jan 29, 2022
338fc1e
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Feb 7, 2022
25245ef
Fixed minor issue with Python-C AddFunctions
jim19930609 Feb 8, 2022
0371b28
Adjusted python-level trace_op to accommodate final state Eager Dygraph
jim19930609 Jan 28, 2022
3482e29
Added Logs for final state Eager Dygraph
jim19930609 Jan 29, 2022
eb8dd27
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Feb 8, 2022
e5266f9
Fixed merge issues
jim19930609 Feb 8, 2022
5c720cd
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
jim19930609 Feb 9, 2022
d32e3e7
Merge branch 'refactor_python_c_attrs' of /~https://github.com/jim19930…
jim19930609 Feb 9, 2022
c527965
Merged from development branch
jim19930609 Feb 9, 2022
d32f64b
Merged from develop
jim19930609 Feb 10, 2022
5ddd81a
Fixed minor issue
jim19930609 Feb 11, 2022
Diff view
@@ -17,6 +17,12 @@
import argparse
import os

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}


def ParseArguments():
parser = argparse.ArgumentParser(
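
For context, the three dicts introduced above are filled in later by CollectCoreOpsInformation and are intended for the python-level API dispatch that the PR's trace_op changes rely on. A minimal sketch of what they might hold after code generation, using a made-up op name and argument names rather than anything taken from this diff:

# Hypothetical contents after the generator has processed one forward API;
# "final_state_matmul" and its argument names are illustrative only.
core_ops_args_info = {"final_state_matmul": ["x", "y", "trans_x", "trans_y"]}
core_ops_args_type_info = {"final_state_matmul": ["tensor", "tensor", "", ""]}
core_ops_returns_info = {"final_state_matmul": ["out"]}
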
@@ -130,17 +136,16 @@ def ParseYamlArgs(string):
attrs_list = []

args = [x.strip() for x in string.strip().split(",")]

atype = r'((const )?\S+) '
aname = r'(\S+)'
aname = r'(.*)'
pattern = f'{atype}{aname}'
for i in range(len(args)):
arg = args[i]
m = re.search(pattern, arg)
arg_type = m.group(1)
arg_name = m.group(3).split("=")[0]
default_value = m.group(3).split("=")[1] if len(m.group(3).split(
"=")) > 1 else None
arg_type = m.group(1).strip()
arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
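
The relaxed aname pattern (.*) lets the third regex group capture an optional default value after the argument name (e.g. scale = 1.0), which is then split on "=" and stripped. A self-contained sketch of that parsing logic with two illustrative arguments, not taken from a real api.yaml entry:

import re

# Same pattern as in ParseYamlArgs above; the sample arguments are made up.
atype = r'((const )?\S+) '
aname = r'(.*)'
pattern = f'{atype}{aname}'

for arg in ["const Tensor& x", "float scale = 1.0"]:
    m = re.search(pattern, arg)
    arg_type = m.group(1).strip()
    name_and_default = m.group(3).split("=")
    arg_name = name_and_default[0].strip()
    default_value = name_and_default[1].strip() if len(name_and_default) > 1 else None
    print(arg_type, arg_name, default_value)

# Prints:
#   const Tensor& x None
#   float scale 1.0
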
@@ -262,7 +267,6 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]

assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
@@ -741,26 +745,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
))
inputs_args_list = ["" for i in range(num_inputs)]
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
inputs_args_list[
inputs_args_definition_list[
pos] = f"const paddle::experimental::Tensor& {name}"
inputs_args_declaration_list[
pos] = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
inputs_args_list[
inputs_args_definition_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"

for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
if default_val is not None:
inputs_args_list[pos] = f"{atype} {name} = {default_val}"
inputs_args_declaration_list[
pos] = f"{atype} {name} = {default_val}"
else:
inputs_args_list[pos] = f"{atype} {name}"
inputs_args_declaration_list[pos] = f"{atype} {name}"
inputs_args_definition_list[pos] = f"{atype} {name}"

inputs_args_str = ", ".join(inputs_args_list)
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)

# Forward Full Logic
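
Splitting the former inputs_args_list into separate declaration and definition lists matters because a C++ default argument may only be written once: the generated header declaration keeps it, while the generated .cc definition omits it. A rough sketch of the two generated signatures for a hypothetical attribute ("float", "scale", default "1.0"); the function name below is made up, not the real generated name:

# Hypothetical generated C++ signatures, held as Python strings by the generator.
declaration = "Tensor hypothetical_forward(const paddle::experimental::Tensor& x, float scale = 1.0);"
definition_signature = "Tensor hypothetical_forward(const paddle::experimental::Tensor& x, float scale)"
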
@@ -812,13 +824,95 @@

forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_str,
returns_type_str, forward_function_name, inputs_args_definition_str,
forward_call_str, node_creation_str, returns_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});"
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"

return forward_function_str, forward_function_declaration_str


def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list):
# fwd_api_name : ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())

final_state_fwd_api_name = "final_state_" + fwd_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype):
core_ops_args_type_info[final_state_fwd_api_name][pos] = "tensor"
else:
assert IsVectorTensorType(ttype)
core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"

for name, _, _, pos in forward_attrs_list:
core_ops_args_info[final_state_fwd_api_name][pos] = name

for name, (ttype, pos) in forward_outputs_position_map.items():
core_ops_returns_info[final_state_fwd_api_name][pos] = name


def GenerateCoreOpInfoDeclaration():
core_ops_declaration_str = """
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;

"""
return core_ops_declaration_str


def GenerateCoreOpInfoDefinition():

CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info = {{
{}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info = {{
{}
}};

"""
op_args_info_list = []
for op_name, arg_list in core_ops_args_info.items():
arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }},"
op_args_info_list.append(op_args_info)

op_types_info_list = []
for op_name, type_list in core_ops_args_type_info.items():
type_str = ",".join(["\"" + v + "\"" for v in type_list])
op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }},"
op_types_info_list.append(op_types_info)

op_returns_info_list = []
for op_name, return_list in core_ops_returns_info.items():
return_str = ",".join(["\"" + v + "\"" for v in return_list])
return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }},"
op_returns_info_list.append(return_types_info)

op_args_info_str = "\n".join(op_args_info_list)
op_types_info_str = "\n".join(op_types_info_list)
op_returns_info_str = "\n".join(op_returns_info_list)

core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format(
op_args_info_str, op_types_info_str, op_returns_info_str)

return core_ops_info_definition_str
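
Each loop above emits one brace-initializer line per op into the generated C++ maps. A minimal, runnable sketch of how a single entry string is assembled (the op and argument names are made up):

# Sketch of one generated entry; "final_state_matmul" and its args are illustrative.
op_name = "final_state_matmul"
arg_list = ["x", "y", "trans_x", "trans_y"]
arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
print(f"{{ \"{op_name}\", {{ {arg_str} }} }},")
# -> { "final_state_matmul", { "x","y","trans_x","trans_y" } },
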


def GenerateNodeCCFile(filepath, node_definition_str):
file_contents = """
#include "glog/logging.h"
@@ -856,6 +950,8 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/utils/global_utils.h"

"""

file_contents += GenerateCoreOpInfoDefinition()
file_contents += forward_definition_str
with open(filepath, 'a') as f:
f.write(file_contents)
@@ -871,6 +967,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
#include "paddle/fluid/framework/op_registry.h"

"""
file_contents += GenerateCoreOpInfoDeclaration()
file_contents += forward_function_declaration_str
with open(filepath, 'a') as f:
f.write(file_contents)
@@ -985,6 +1082,11 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]

# For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map,
forward_attrs_list)

# Generate Files
nodes_h_path = args.nodes_h_path
nodes_cc_path = args.nodes_cc_path
@@ -104,6 +104,8 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
PyThreadState *tstate = nullptr;
try
{{
VLOG(6) << "Running Eager Final State API: {}";

// Get EagerTensors from args
{}

@@ -129,16 +131,87 @@

"""
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, get_eager_tensor_str, parse_attributes_str,
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)

python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}},\n"

return python_c_function_str, python_c_function_reg_str


def GenerateCoreOpsInfoMap():
result = """
static PyObject * eager_get_final_state_core_ops_args_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}

static PyObject * eager_get_final_state_core_ops_args_type_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_args_type_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}

static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
return ToPyObject(core_ops_final_state_returns_info);
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
"""

core_ops_infos_registry = """
{\"get_final_state_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_info.\"},
{\"get_final_state_core_ops_args_type_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_type_info,
METH_NOARGS,
\"C++ interface function for eager_get_final_state_core_ops_args_type_info.\"},
{\"get_final_state_core_ops_returns_info\",
(PyCFunction)(void(*)(void))eager_get_final_state_core_ops_returns_info,
METH_NOARGS, \"C++ interface function for eager_get_final_state_core_ops_returns_info.\"},
"""

return result, core_ops_infos_registry
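
Once registered, these three METH_NOARGS getters expose the generated maps to Python, which is what lets the python-level trace_op look up argument and return names for a final state op. A hedged sketch of the intended call pattern; the import path below is an assumption for illustration and is not spelled out in this diff:

# Assumed access pattern; "paddle.fluid.core" / core.eager is an assumption here,
# check the python-side trace_op changes for the exact module actually used.
from paddle.fluid import core  # hypothetical import path

args_info = core.eager.get_final_state_core_ops_args_info()
# e.g. {"final_state_matmul": ["x", "y", "trans_x", "trans_y"], ...}
returns_info = core.eager.get_final_state_core_ops_returns_info()
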


def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):

core_ops_infos_definition, core_ops_infos_registry = GenerateCoreOpsInfoMap(
)

python_c_function_str += core_ops_infos_definition
python_c_function_reg_str += core_ops_infos_registry
python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"

PYTHON_C_WRAPPER_TEMPLATE = """
#pragma once

@@ -215,12 +288,12 @@ def GeneratePythonCFile(filepath, python_c_str):
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)

python_c_function_reg_list.append("{nullptr,nullptr,0,nullptr}")
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)

python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)

print("Generated Python-C Codes: ", python_c_str)

output_path = args.output_path