From 264b757dc0cacebb7569931a4dc229de4a1b1166 Mon Sep 17 00:00:00 2001 From: veyron95 Date: Tue, 15 Feb 2022 03:35:16 +0000 Subject: [PATCH] Polish code, make it clear and simple --- .../test_imperative_hook_for_layer.py | 152 ++---------------- 1 file changed, 9 insertions(+), 143 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py index 0c238b4a45156a..4c457e9345c5d3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py @@ -30,9 +30,6 @@ call_forward_post_hook = False call_forward_pre_hook = False -call_forward_post_hook_eager = False -call_forward_pre_hook_eager = False - def forward_post_hook(layer, input, output): global call_forward_post_hook @@ -44,16 +41,6 @@ def forward_pre_hook(layer, input): call_forward_pre_hook = True -def eager_forward_post_hook(layer, input, output): - global call_forward_post_hook_eager - call_forward_post_hook_eager = True - - -def eager_forward_pre_hook(layer, input): - global call_forward_pre_hook_eager - call_forward_pre_hook_eager = True - - def forward_post_hook1(layer, input, output): return output * 2 @@ -65,7 +52,7 @@ def forward_pre_hook1(layer, input): class Test_Forward_Hook(unittest.TestCase): # test forward_pre_hook and forward_post_hook that have return value - def test_forward_hook_return_value(self): + def func_forward_hook_return_value(self): seed = 90 places = [fluid.CPUPlace()] @@ -118,7 +105,7 @@ def test_forward_hook_return_value(self): self.assertTrue( np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy())) - # register forward_post_hook + # register forward_posst_hook forward_post_hook_handle1 = simplenet.register_forward_post_hook( forward_post_hook1) outs_forward_hook = simplenet(input, y) @@ -133,72 +120,8 @@ def test_forward_hook_return_value(self): np.array_equal(outs_forward_hook.numpy(), outs_origin.numpy())) - for place in places: - with fluid.dygraph.guard(place): - with _test_eager_guard(): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - fluid.set_flags({'FLAGS_sort_sum_gradient': True}) - - input_word = np.array( - [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, - 8]).reshape(6, 3).astype('int64') - input_word1 = input_word * 2 - input_word = input_word.reshape((-1, 3, 1)) - input_word1 = input_word1.reshape((-1, 3, 1)) - y_data = np.array( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, - 9]).reshape(6, 3).astype('int64') - y_data = y_data.reshape((-1, 1)) - - input = base.to_variable(input_word) - input1 = base.to_variable(input_word1) - y = base.to_variable(y_data) - - simplenet = SimpleNet( - hidden_size=20, - vocab_size=32, - num_steps=3, - init_scale=0.1, - is_sparse=False, - dtype="float32") - - # origin, don't register any hook - outs_origin = simplenet(input, y) - outs_origin1 = simplenet(input1, y) - - # register forward_pre_hook - forward_pre_hook_handle1 = simplenet.register_forward_pre_hook( - forward_pre_hook1) - outs_pre_hook = simplenet(input, y) - self.assertTrue( - np.array_equal(outs_pre_hook.numpy(), - outs_origin1.numpy())) - - # remove forward_pre_hook - forward_pre_hook_handle1.remove() - outs_pre_hook = simplenet(input, y) - self.assertTrue( - np.array_equal(outs_pre_hook.numpy(), - outs_origin.numpy())) - - # register forward_post_hook - forward_post_hook_handle1 = simplenet.register_forward_post_hook( - forward_post_hook1) - outs_forward_hook = simplenet(input, y) - self.assertTrue( - np.array_equal(outs_forward_hook.numpy(), - outs_origin.numpy() * 2)) - - # remove forward_post_hook - forward_post_hook_handle1.remove() - outs_forward_hook = simplenet(input, y) - self.assertTrue( - np.array_equal(outs_forward_hook.numpy(), - outs_origin.numpy())) - # test forward_pre_hook and forward_post_hook that don't have return value - def test_forward_hook(self): + def func_forward_hook(self): seed = 90 places = [fluid.CPUPlace()] @@ -268,69 +191,12 @@ def test_forward_hook(self): self.assertFalse(call_forward_post_hook) self.assertFalse(call_forward_pre_hook) - for place in places: - with fluid.dygraph.guard(place): - with _test_eager_guard(): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - fluid.set_flags({'FLAGS_sort_sum_gradient': True}) - - global call_forward_post_hook_eager - global call_forward_pre_hook_eager - - input_word = np.array( - [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, - 8]).reshape(6, 3).astype('int64') - input_word = input_word.reshape((-1, 3, 1)) - y_data = np.array( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, - 9]).reshape(6, 3).astype('int64') - y_data = y_data.reshape((-1, 1)) - - input = base.to_variable(input_word) - y = base.to_variable(y_data) - - simplenet = SimpleNet( - hidden_size=20, - vocab_size=32, - num_steps=3, - init_scale=0.1, - is_sparse=False, - dtype="float32") - - # origin, don't register any hook - outs_origin = simplenet(input, y) - self.assertFalse(call_forward_post_hook_eager) - self.assertFalse(call_forward_pre_hook_eager) - - # register forward_post_hook and forward_pre_hook - forward_post_hook_handle = simplenet.register_forward_post_hook( - eager_forward_post_hook) - forward_pre_hook_handle = simplenet.register_forward_pre_hook( - eager_forward_pre_hook) - outs_hook = simplenet(input, y) - self.assertTrue(call_forward_post_hook_eager) - self.assertTrue(call_forward_pre_hook_eager) - - outs_hook = simplenet(input, y) - self.assertTrue(call_forward_post_hook_eager) - self.assertTrue(call_forward_pre_hook_eager) - - # remove forward_post_hook - forward_post_hook_handle.remove() - call_forward_post_hook_eager = False - call_forward_pre_hook_eager = False - outs_remove_forward_hook = simplenet(input, y) - self.assertFalse(call_forward_post_hook_eager) - self.assertTrue(call_forward_pre_hook_eager) - - # remove forward_pre_hook - forward_pre_hook_handle.remove() - call_forward_post_hook_eager = False - call_forward_pre_hook_eager = False - outs_remove_hook = simplenet(input, y) - self.assertFalse(call_forward_post_hook_eager) - self.assertFalse(call_forward_pre_hook_eager) + def test_forward_hook_return_value(self): + with _test_eager_guard(): + self.func_forward_hook() + self.func_forward_hook_return_value() + self.func_forward_hook() + self.func_forward_hook_return_value() if __name__ == '__main__':