【prim】add dropout composite rule #50497
@@ -0,0 +1,210 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import parameterized as param

import paddle
from paddle.fluid import core

np.random.seed(2023)


place = (
    paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
)


@param.parameterized_class(
    ('name', 'x', 'p', 'is_test', 'mode', 'seed', 'dtype', 'place'),
    (
        (
            'fp32',
            np.random.rand(100000),
            0.3,
            False,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'fp64',
            np.random.rand(100000),
            0.7,
            False,
            'upscale_in_train',
            9999,
            'float64',
            place,
        ),
        (
            'is_test=True',
            np.random.rand(100000),
            0.5,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=1.0',
            np.random.rand(100000),
            1.0,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=1.0,test=False',
            np.random.rand(100000),
            1.0,
            False,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'p=0.0',
            np.random.rand(100000),
            0.0,
            True,
            'upscale_in_train',
            1002,
            'float32',
            place,
        ),
        (
            'downgrade_train',
            np.random.rand(100000),
            0.5,
            False,
            'downscale_in_infer',
            1002,
            'float32',
            place,
        ),
        (
            'fp32_cpu',
            np.random.rand(100000),
            0.6,
            False,
            'upscale_in_train',
            9899,
            'float32',
            paddle.CPUPlace(),
        ),
        (
            'fp64_cpu',
            np.random.rand(100000),
            0.6,
            False,
            'upscale_in_train',
            9899,
            'float64',
            paddle.CPUPlace(),
        ),
        (
            'downgrade_train_cpu',
            np.random.rand(100000),
            0.5,
            False,
            'downscale_in_infer',
            1002,
            'float32',
            paddle.CPUPlace(),
        ),
    ),
)
class TestCompositeDropout(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        paddle.enable_static()
        cls.x = cls.x.astype(cls.dtype)

    @classmethod
    def tearDownClass(cls):
        paddle.disable_static()

    def test_comp(self):
        def dropout(x, p, is_test, mode, seed=0):
            paddle.seed(seed)
            mp, sp = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                input_ = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                input_.stop_gradient = False
                output = paddle.nn.functional.dropout(
                    input_, p, training=(not is_test), mode=mode
                )
                if core._is_fwd_prim_enabled():
                    paddle.incubate.autograd.to_prim(mp.blocks)
                grad = paddle.static.gradients(output, input_)[0]
            exe = paddle.static.Executor(self.place)
            exe.run(sp)
            fwd, rev = exe.run(
                mp, feed={input_.name: x}, fetch_list=[output, grad]
            )
            return fwd, rev, mp

        core._set_prim_forward_enabled(False)
        desired_fwd, desired_rev, _ = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )

        core._set_prim_forward_enabled(True)
        actual_fwd, actual_rev, prog = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )

        self.assertTrue('dropout' not in [op.type for op in prog.block(0).ops])

        np.testing.assert_allclose(
            actual_fwd.sum(),
            desired_fwd.sum(),
            rtol=1e-2,  # loose tolerance: sums differ only by mask sampling noise
            atol=0,
        )
        np.testing.assert_allclose(
            actual_rev.sum(),
            desired_rev.sum(),
            rtol=1e-2,  # loose tolerance: sums differ only by mask sampling noise
            atol=0,
        )

        core._set_prim_all_enabled(True)
        actual_fwd, actual_rev, _ = dropout(
            self.x, self.p, self.is_test, self.mode, self.seed
        )
        np.testing.assert_allclose(
            actual_fwd.sum(),
            desired_fwd.sum(),
            rtol=1e-2,  # loose tolerance: sums differ only by mask sampling noise
            atol=0,
        )
        np.testing.assert_allclose(
            actual_rev.sum(),
            desired_rev.sum(),
            rtol=1e-2,  # loose tolerance: sums differ only by mask sampling noise
            atol=0,
        )


if __name__ == '__main__':
    unittest.main()
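A note on the tolerance used above: the decomposed program presumably draws its dropout mask from a different random stream than the fused dropout kernel, so the two outputs cannot be compared elementwise; only their sums are compared with rtol=1e-2. The small NumPy sketch below (illustration only, not part of the PR; upscale_in_train_ref is an invented helper) shows why sums over 100000 uniform inputs agree to well under a percent even with independent masks.

import numpy as np


def upscale_in_train_ref(x, p, rng):
    # keep each element with probability (1 - p), then rescale so the
    # expected value of the output matches the input
    mask = (rng.random(x.shape) >= p).astype(x.dtype)
    return x * mask / (1.0 - p)


x = np.random.default_rng(2023).random(100000)
out_a = upscale_in_train_ref(x, 0.3, np.random.default_rng(1))
out_b = upscale_in_train_ref(x, 0.3, np.random.default_rng(2))
# Independent masks, yet both sums land within a fraction of a percent of
# x.sum(), so a 1e-2 relative tolerance on the sums is rarely violated.
print(abs(out_a.sum() - out_b.sum()) / x.sum())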
@@ -20,6 +20,8 @@
import functools
import operator

from paddle.fluid import core

from .primitives import *  # noqa: F403
from .primreg import REGISTER_COMPOSITE, lookup_composite


@@ -149,3 +151,46 @@ def mean_composite(x, axis, keepdim):
        dtype=sum_x.dtype,
    )
    return divide(sum_x, norm)


@REGISTER_COMPOSITE('dropout')
def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):
Review comment: The function signature should be consistent with the dropout 2.x API.
Reply: The current 2.0 API does not match the definition in legacy_ops.yaml; the forward decomposition splits the op into finer-grained ops, so for now the signature cannot be kept consistent with the 2.0 API. Once legacy_ops is migrated, this part will be migrated and unified with it.
Reply: understand
    """Define the composite rule of op dropout.
    upscale_in_train:
        train: out = input * mask / (1.0 - p)
        inference: out = input
    downscale_in_infer:
        train: out = input * mask
        inference: out = input * (1.0 - p)
    """
    fix_seed = True if fix_seed is None else fix_seed
    seed = seed if fix_seed else 0
    upscale_in_train = mode == "upscale_in_train"
    mask = bernoulli(shape=x.shape, dtype=x.dtype, p=p, seed=seed)

    if upscale_in_train:
        if not is_test:
            # handle p=1.0 separately to avoid dividing by zero in x * mask / (1.0 - p)
            if p == 1.0:
                return 0.0 * x, zeros(x.shape, core.VarDesc.VarType.UINT8)
            else:
                return x * mask / (1.0 - p), cast(
                    mask, core.VarDesc.VarType.UINT8
                )
        else:
            return assign(x), cast(mask, core.VarDesc.VarType.UINT8)
    else:
        if not is_test:
            return x * mask, cast(mask, core.VarDesc.VarType.UINT8)
        else:
            return x * (1.0 - p), cast(mask, core.VarDesc.VarType.UINT8)


def bernoulli(shape, dtype, p, seed=0):
    return cast(
        greater_equal(
            uniform(shape, dtype, min=0.0, max=1.0, seed=seed),
            fill_constant(shape, dtype, p),
        ),
        dtype,
    )
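For readers not fluent in Paddle's primitive ops, here is a rough NumPy rendering of the same rule (an illustrative sketch, not PaddlePaddle code; dropout_reference and its rng argument are invented for the example). The mask is built exactly as in bernoulli above: uniform samples compared against p, so each element is kept with probability 1 - p.

import numpy as np


def dropout_reference(x, p, is_test, mode, rng):
    # mirror bernoulli(): mask is 1 where a uniform sample is >= p
    mask = (rng.uniform(0.0, 1.0, x.shape) >= p).astype(x.dtype)
    if mode == "upscale_in_train":
        if not is_test:
            if p == 1.0:  # everything is dropped; avoid dividing by zero
                return 0.0 * x, np.zeros(x.shape, dtype=np.uint8)
            return x * mask / (1.0 - p), mask.astype(np.uint8)
        return x.copy(), mask.astype(np.uint8)
    # downscale_in_infer
    if not is_test:
        return x * mask, mask.astype(np.uint8)
    return x * (1.0 - p), mask.astype(np.uint8)


out, mask = dropout_reference(
    np.random.rand(8).astype("float32"),
    0.5,
    False,
    "upscale_in_train",
    np.random.default_rng(1002),
)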
Review comment: Please add an op check to confirm that the dropout op is actually decomposed.
Reply: done