Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Eager Hook] Support eager hook_for_layer #39531

Merged
merged 5 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def register_forward_pre_hook(self, hook):
import paddle
import numpy as np

# the forward_post_hook change the input of the layer: input = input * 2
# the forward_pre_hook change the input of the layer: input = input * 2
def forward_pre_hook(layer, input):
# user can use layer and input for information statistis tasks

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,22 +25,23 @@
import paddle.fluid.dygraph.base as base

from test_imperative_lod_tensor_to_selected_rows import SimpleNet
from paddle.fluid.framework import _test_eager_guard

call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False


def forward_hook(layer, input, output):
global call_forward_hook
call_forward_hook = True
def forward_post_hook(layer, input, output):
global call_forward_post_hook
call_forward_post_hook = True


def forward_pre_hook(layer, input):
global call_forward_pre_hook
call_forward_pre_hook = True


def forward_hook1(layer, input, output):
def forward_post_hook1(layer, input, output):
return output * 2


Expand All @@ -50,8 +51,8 @@ def forward_pre_hook1(layer, input):


class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_hook that have return value
def test_forward_hook_return_value(self):
# test forward_pre_hook and forward_post_hook that have return value
def func_forward_hook_return_value(self):
seed = 90

places = [fluid.CPUPlace()]
Expand Down Expand Up @@ -104,23 +105,23 @@ def test_forward_hook_return_value(self):
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))

# register forward_hook
forward_hook_handle1 = simplenet.register_forward_post_hook(
forward_hook1)
# register forward_posst_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))

# remove forward_hook
forward_hook_handle1.remove()
# remove forward_post_hook
forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))

# test forward_pre_hook and forward_hook that don't have return value
def test_forward_hook(self):
# test forward_pre_hook and forward_post_hook that don't have return value
def func_forward_hook(self):
seed = 90

places = [fluid.CPUPlace()]
Expand All @@ -133,7 +134,7 @@ def test_forward_hook(self):
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

global call_forward_hook
global call_forward_post_hook
global call_forward_pre_hook

input_word = np.array(
Expand All @@ -158,38 +159,45 @@ def test_forward_hook(self):

# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)

# register forward_hook and forward_pre_hook
forward_hook_handle = simplenet.register_forward_post_hook(
forward_hook)
# register forward_post_hook and forward_pre_hook
forward_post_hook_handle = simplenet.register_forward_post_hook(
forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

# remove forward_hook
forward_hook_handle.remove()
call_forward_hook = False
# remove forward_post_hook
forward_post_hook_handle.remove()
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)

def test_forward_hook_return_value(self):
with _test_eager_guard():
self.func_forward_hook()
self.func_forward_hook_return_value()
self.func_forward_hook()
self.func_forward_hook_return_value()


if __name__ == '__main__':
unittest.main()