Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.91】 #52948

Merged
merged 5 commits into from
Apr 27, 2023
Merged

Conversation

yangguohao
Copy link
Contributor

@yangguohao yangguohao commented Apr 15, 2023

PR types

Others

PR changes

Others

Description

register_hook for static mode

@paddle-bot
Copy link

paddle-bot bot commented Apr 15, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 15, 2023
@yangguohao
Copy link
Contributor Author

yangguohao commented Apr 17, 2023

通过将函数转为 AST 改写代码使得 动转静 下的 register_hook 能成功运行。

之后还可以提升的地方

  1. 如果 register_hook 存在于一个内部的函数中,这种情况没有办法考虑到,如下面的代码所示
@to_static
def f(x):
    def hook(g):
        return g * 2
    def inner_function():
        x.register_hook(hook)
    inner_function()
  1. 没有考虑实现 hook.remove 的方法
  2. 改写后的代码的行数与源代码不一致,例如 注释,多余空行等,会被自动优化掉,可能会出现不可预料的报错,或者信息不对等问题。

@Ligoml Ligoml requested a review from Aurelius84 April 19, 2023 03:03
Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

非常感谢您的贡献。这里方案有几个问题待商榷:

  1. PR支持了被to_static装饰下使用register_hook的场景,对于继承了nn.Layer的class,如果register_hook函数是self.__init__函数中调用的,则被@to_static装饰的forward函数里registerc_hook是否会正确触发呢?

@@ -650,6 +650,10 @@ def func_to_source_code(function, dedent=True):
for line in source_code_list
]
source_code = ''.join(source_code_list)
# check the 'register hook' in the source code
if 'register_hook' in source_code:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是一个AST层面的变换,推荐写成一个单独的AstTransformer。因为source_code层面的str匹配是有风险的。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用 NodeVisitor 来进行遍历修改节点

@@ -38,3 +41,69 @@ def pretty_source(source):

source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code


def modify_function_code(func, code_str='register_hook'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

几点建议哈:

  1. 建议将其以AstTransformer形式接入进来,可以参考dy2stat目录下其他Transformer
  2. 注释建议使用英文
  3. 移除不必要的 print等注释代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释改为英文、移除 print 等不必要的注释代码

"""do nothing but return a new variable."""
return x

# class HookRemoveHelper:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议移除注释代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@yangguohao
Copy link
Contributor Author

非常感谢您的贡献。这里方案有几个问题待商榷:

  1. PR支持了被to_static装饰下使用register_hook的场景,对于继承了nn.Layer的class,如果register_hook函数是self.__init__函数中调用的,则被@to_static装饰的forward函数里registerc_hook是否会正确触发呢?

是类似之前的这个 RFC 提到的 LinearNet 的例子吗,是可以被正确的触发的。

@Aurelius84
Copy link
Contributor

是的,RFC里是对params注册了hook,是否可以在test/unittest/dygraph_to_static/ 目录下添加一个test_tensor_hook.py,丰富下不同使用场景的单测单元case?

@yangguohao
Copy link
Contributor Author

增加了 test_tensor_hook.py 的测试代码,设定了几种形式,包括了在 Layer 的 forward 中对参数 register_hook,这里我把 varbase_patch_method 中 monkey_patch_varbase 中 register_hook 的 dygraph_only 装饰器注释了,发现代码都可以直接运行。

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for overall


def backward_hook_wrapper(dy):
"""call the backward hook in ."""
import numpy as np
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在函数内添加import numpy ,是因为此函数会在py_func里执行,且用户可能没有在模型代码里添加import ?

@@ -38,3 +41,84 @@ def pretty_source(source):

source_code = astor.to_source(ast_node, pretty_source=pretty_source)
return source_code


class RegisterHookVisitor(gast.NodeVisitor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如下 comment 可以考虑下个PR做优化。

  1. ast_utils.py 为公共文件,所有上层的Ast Transformer变换逻辑按照规范是独立一个文件的。
  2. 所有的AST 变换建议都继承 BaseTransformer,并提供def transform(self): 方法,因为报错栈回溯逻辑是在基类BaseTransformer中实现的。

func_def.body = new_body


def modify_function_code(func):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样下个PR可以优化:

  1. 动转静的AST Transfomer 逻辑是放在 ast_transformer.py 统一生效的,里面有一个list,会逐个应用生效

if dedent:
source_code = textwrap.dedent(source_code)
# return modified function source code if there is 'register_hook', otherwise return None
source_code = modify_function_code(function)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如前面的comment,ast_to_func函数的职责只负责借助 AST 生成python module,因此建议下个PR可否将此处逻辑放到 ast_transformer.py

loss_jit = jit_layer(image_jit)
loss_jit.backward()
loss.backward()
self.assertTrue(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外从框架内部单测规范上,我们更建议使用类似 np.testing.assert_allclose() 函数,而非self.assertTrue(xxx),详见:/~https://github.com/PaddlePaddle/community/blob/master/rfcs/CodeStyle/20220805_code_style_improvement_for_unittest.md#background

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for framework.py

@luotao1 luotao1 merged commit db30aa1 into PaddlePaddle:develop Apr 27, 2023
@luotao1
Copy link
Contributor

luotao1 commented Apr 27, 2023

备注:黑客松91题还未完成,仍需要PR优化

@yangguohao
Copy link
Contributor Author

yangguohao commented May 7, 2023

Hi @Aurelius84, check this new PR #53572 for updates.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants