-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dygraph Prim code gen (node.cc) #33
Changes from 4 commits
9840634
34c6fe8
dea7152
247d207
d55fe36
7a1e172
c951ebe
eaece38
b4fcb0a
bc81c92
f91bfd8
7d37d4b
dcfa5b6
11f6786
bd51eb1
9085192
dd68035
8553b59
858ec94
dc3fc79
bca447b
fa3b1cf
2e00145
a5705e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -332,6 +332,9 @@ class {} : public egr::GradNodeBase {{ | |
#include "paddle/fluid/eager/nan_inf_utils.h" | ||
#include "paddle/phi/api/include/sparse_api.h" | ||
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" | ||
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" | ||
#include "paddle/fluid/prim/api/all.h" | ||
#include "paddle/fluid/prim/utils/utils.h" | ||
DECLARE_bool(check_nan_inf); | ||
{} | ||
""" | ||
|
@@ -546,6 +549,7 @@ def __init__( | |
# self.forward_outputs_position_map | ||
# self.optional_inputs | ||
# self.no_need_buffers | ||
# self.composite_func_info | ||
# self.intermediate_outputs | ||
# self.forward_inplace_map | ||
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace) | ||
|
@@ -871,6 +875,7 @@ def GenerateNodeCreationCodes(self, for_backward=False): | |
backward_grad_outputs_map = self.backward_grad_outputs_map | ||
backward_attrs_list = self.backward_attrs_list | ||
optional_inputs = self.optional_inputs | ||
is_composite_forward_api = False if self.composite_func_info == [] else True | ||
|
||
# Pass Stop Gradient Args | ||
pass_stop_gradient_args_str = self.GetPassStopGradientArgsList( | ||
|
@@ -1056,6 +1061,8 @@ def run(self): | |
self.ParseBackwardInplaceInfo() | ||
# Parse no_need_buffer | ||
self.ParseNoNeedBuffer() | ||
# Parse composite | ||
self.ParseComposite() | ||
|
||
# Parse optional_inputs | ||
self.ParseDispensable() | ||
|
@@ -1826,16 +1833,19 @@ def GenerateHigherOrderNodeCreationCode(self): | |
is_invoke_forward_api = IsInvokeForwardApi( | ||
self.grad_api_contents, self.forward_apis_dict | ||
) | ||
is_composite_forward_api = False if self.composite_func_info == [] else True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why composite forward api? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace by composite grad api |
||
|
||
if next_node_generator is not None: | ||
has_higher_order_node = True | ||
return ( | ||
has_higher_order_node, | ||
is_invoke_forward_api, | ||
is_composite_forward_api, | ||
next_grad_node_creation_str, | ||
next_grad_node_out_list, | ||
next_node_generator.backward_forward_inputs_map, | ||
) | ||
elif not is_invoke_forward_api: | ||
elif not is_invoke_forward_api and not is_composite_forward_api: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if it hits else branch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if is_invoke_farward_api or is_composite_grad_api next_grad_node_creation_str should be none, we will add this when delete Flags_prim_enabled |
||
next_grad_node_creation_str = f""" if(trace_backward) {{ | ||
PADDLE_THROW(phi::errors::Unavailable( | ||
\"The Op {self.backward_api_name} doesn't have any grad\" | ||
|
@@ -1845,6 +1855,7 @@ def GenerateHigherOrderNodeCreationCode(self): | |
return ( | ||
has_higher_order_node, | ||
is_invoke_forward_api, | ||
is_composite_forward_api, | ||
next_grad_node_creation_str, | ||
next_grad_node_out_list, | ||
None, | ||
|
@@ -1942,13 +1953,15 @@ def GenerateNodeDefinition( | |
self, | ||
has_higher_order_node, | ||
is_invoke_forward_api, | ||
is_composite_grad_api, | ||
next_grad_node_creation_str, | ||
next_grad_node_out_list, | ||
backward_forward_inputs_map_next, | ||
): | ||
namespace = self.namespace | ||
forward_api_name = self.forward_api_name | ||
backward_api_name = self.backward_api_name | ||
composite_backward_api_name = self.composite_func_info[0] if is_composite_grad_api else None | ||
backward_forward_inputs_map = self.backward_forward_inputs_map | ||
backward_grad_inputs_map = self.backward_grad_inputs_map | ||
backward_grad_outputs_map = self.backward_grad_outputs_map | ||
|
@@ -2133,6 +2146,7 @@ def GenerateNodeDefinition( | |
# Grad Function Call String | ||
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) | ||
grad_api_namespace = f"paddle::experimental::{namespace}" | ||
composite_grad_api_namespace = f"paddle::prim::{namespace}" | ||
grad_function_prepare_str = f""" | ||
const auto& out_metas = OutputMeta(); | ||
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs}); | ||
|
@@ -2203,6 +2217,8 @@ def GenerateNodeDefinition( | |
}}""" | ||
|
||
grad_api_args_str = ", ".join(grad_api_args) | ||
composite_grad_api_args_str = ", ".join(grad_api_args) | ||
composite_template_name = "<paddle::experimental::Tensor>" | ||
|
||
if is_invoke_forward_api: | ||
autograd_api_out = "auto" | ||
|
@@ -2225,6 +2241,16 @@ def GenerateNodeDefinition( | |
{out_assign_str}}} else {{ | ||
{indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']}; | ||
{out_assign_str}{indent}}} | ||
""" | ||
elif is_composite_grad_api: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave TODO here to indicate strategy which will be used later, such as using composite only when we don't have backward kernel There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be statically generated here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
grad_function_call_str = f""" | ||
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{ | ||
{indent}{composite_grad_api_namespace}{composite_backward_api_name}{composite_template_name}({composite_grad_api_args_str}); | ||
VLOG(4) << paddle::string::Sprintf("composite api %s is called" , "{composite_backward_api_name}"); | ||
}}else{{ | ||
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str}); | ||
VLOG(4) << paddle::string::Sprintf("origin api %s is called" , "{backward_api_name}"); | ||
}} | ||
""" | ||
else: | ||
grad_function_call_str = f""" | ||
|
@@ -2328,6 +2354,9 @@ def GenerateNodeDefinition( | |
var_str += f"\n{indent} output_str += output_{new_name}_str; " | ||
|
||
log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str) | ||
# TODO Ruting modify in the future | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO with wrong format There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. modified |
||
# if is_composite_forward_api: | ||
# next_grad_node_creation_str = '' | ||
|
||
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format( | ||
grad_node_name, | ||
|
@@ -2361,6 +2390,7 @@ def run(self): | |
( | ||
has_higher_order_node, | ||
is_invoke_forward_api, | ||
is_composite_grad_api, | ||
next_grad_node_creation_str, | ||
next_grad_node_out_list, | ||
backward_forward_inputs_map, | ||
|
@@ -2371,6 +2401,7 @@ def run(self): | |
self.GenerateNodeDefinition( | ||
has_higher_order_node, | ||
is_invoke_forward_api, | ||
is_composite_grad_api, | ||
next_grad_node_creation_str, | ||
next_grad_node_out_list, | ||
backward_forward_inputs_map, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove additional generated files