diff --git a/python/paddle/fluid/tests/unittests/test_glu.py b/python/paddle/fluid/tests/unittests/test_glu.py index b5cf6f3dca8b3b..64318858d19029 100644 --- a/python/paddle/fluid/tests/unittests/test_glu.py +++ b/python/paddle/fluid/tests/unittests/test_glu.py @@ -79,7 +79,7 @@ def glu_axis_size(self): paddle.nn.functional.glu(x, axis=256) def test_errors(self): - self.assertRaises(AssertionError, self.glu_axis_size) + self.assertRaises(ValueError, self.glu_axis_size) if __name__ == '__main__': diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 934817e764f5a7..9d18c2386414cb 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1622,7 +1622,13 @@ def glu(x, axis=-1, name=None): check_variable_and_dtype( x, 'input', ['float16', 'float32', 'float64'], "glu" ) - assert axis < len(x.shape), "axis must < rank(x)" + rank = len(x.shape) + if not (-rank <= axis < rank): + raise ValueError( + "Expected value range of `axis` is [{}, {}), but received axis: {}".format( + -rank, rank, axis + ) + ) a, b = chunk(x, 2, axis=axis, name=name) gate = sigmoid(b, name=name) out = paddle.multiply(a, gate, name=name)