Skip to content

Commit

Permalink
Polish code, make it clear and simple
Browse files Browse the repository at this point in the history
  • Loading branch information
veyron95 committed Feb 15, 2022
1 parent 47b507a commit 264b757
Showing 1 changed file with 9 additions and 143 deletions.
152 changes: 9 additions & 143 deletions python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
call_forward_post_hook = False
call_forward_pre_hook = False

call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False


def forward_post_hook(layer, input, output):
global call_forward_post_hook
Expand All @@ -44,16 +41,6 @@ def forward_pre_hook(layer, input):
call_forward_pre_hook = True


def eager_forward_post_hook(layer, input, output):
global call_forward_post_hook_eager
call_forward_post_hook_eager = True


def eager_forward_pre_hook(layer, input):
global call_forward_pre_hook_eager
call_forward_pre_hook_eager = True


def forward_post_hook1(layer, input, output):
    """Post-hook with a return value: replaces the layer output with its double."""
    doubled = 2 * output
    return doubled

Expand All @@ -65,7 +52,7 @@ def forward_pre_hook1(layer, input):

class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_post_hook that have return value
def test_forward_hook_return_value(self):
def func_forward_hook_return_value(self):
seed = 90

places = [fluid.CPUPlace()]
Expand Down Expand Up @@ -118,7 +105,7 @@ def test_forward_hook_return_value(self):
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))

# register forward_post_hook
# register forward_post_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
Expand All @@ -133,72 +120,8 @@ def test_forward_hook_return_value(self):
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))

for place in places:
with fluid.dygraph.guard(place):
with _test_eager_guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word1 = input_word * 2
input_word = input_word.reshape((-1, 3, 1))
input_word1 = input_word1.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))

input = base.to_variable(input_word)
input1 = base.to_variable(input_word1)
y = base.to_variable(y_data)

simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")

# origin, don't register any hook
outs_origin = simplenet(input, y)
outs_origin1 = simplenet(input1, y)

# register forward_pre_hook
forward_pre_hook_handle1 = simplenet.register_forward_pre_hook(
forward_pre_hook1)
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(),
outs_origin1.numpy()))

# remove forward_pre_hook
forward_pre_hook_handle1.remove()
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(),
outs_origin.numpy()))

# register forward_post_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))

# remove forward_post_hook
forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))

# test forward_pre_hook and forward_post_hook that don't have return value
def test_forward_hook(self):
def func_forward_hook(self):
seed = 90

places = [fluid.CPUPlace()]
Expand Down Expand Up @@ -268,69 +191,12 @@ def test_forward_hook(self):
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)

for place in places:
with fluid.dygraph.guard(place):
with _test_eager_guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

global call_forward_post_hook_eager
global call_forward_pre_hook_eager

input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word = input_word.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))

input = base.to_variable(input_word)
y = base.to_variable(y_data)

simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")

# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertFalse(call_forward_pre_hook_eager)

# register forward_post_hook and forward_pre_hook
forward_post_hook_handle = simplenet.register_forward_post_hook(
eager_forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
eager_forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

outs_hook = simplenet(input, y)
self.assertTrue(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

# remove forward_post_hook
forward_post_hook_handle.remove()
call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertFalse(call_forward_pre_hook_eager)
    def test_forward_hook_return_value(self):
        # Runs both hook tests (despite the method name, func_forward_hook is
        # exercised here too) under the eager-mode guard first, then again
        # outside it so the legacy dygraph path is also covered.
        # NOTE(review): indentation reconstructed from the diff rendering —
        # assumes the second pair of calls sits outside the `with`; confirm
        # against the actual file.
        with _test_eager_guard():
            self.func_forward_hook()
            self.func_forward_hook_return_value()
        self.func_forward_hook()
        self.func_forward_hook_return_value()


if __name__ == '__main__':
Expand Down

1 comment on commit 264b757

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 264b757 Feb 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #39531 Commit ID: 264b757 contains failed CI.

🔹 Failed: PR-CI-MLU

Unknown Failed
2022-02-15 16:53:45 [ 23%] Built target fusion_gru_op
2022-02-15 16:53:45 [ 23%] Built target fused_embedding_fc_lstm_op
2022-02-15 16:53:45 [ 23%] Built target multi_gru_op
2022-02-15 16:53:46 [ 23%] Built target elementwise_max_op
2022-02-15 16:53:51 [ 23%] Built target mine_hard_examples_op
2022-02-15 16:53:51 Makefile:140: recipe for target 'all' failed
2022-02-15 16:53:51 make: *** [all] Error 2
2022-02-15 16:53:51 + build_error=2
2022-02-15 16:53:51 + collect_ccache_hits
2022-02-15 16:53:51 ++ ccache -s
2022-02-15 16:53:51 ++ grep 'cache hit rate'
2022-02-15 16:53:51 ++ awk '{print $4}'
2022-02-15 16:53:54 + rate=96.21
2022-02-15 16:53:54 + echo 'ccache hit rate: 96.21%'
2022-02-15 16:53:54 ccache hit rate: 96.21%
2022-02-15 16:53:54 + echo 'ipipe_log_param_Ccache_Hit_Rate: 96.21%'
2022-02-15 16:53:54 + '[' 2 '!=' 0 ']'
2022-02-15 16:53:54 + exit 7
2022-02-15 16:53:54 {build code state=7}

Please sign in to comment.