diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index ec9a4d5e5ed6ac..92ec624614a83a 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -728,6 +728,17 @@ def test_scheduler(self): step_size_down=-1, scale_mode='test', ) + # check empty boundaries + with self.assertRaises(ValueError): + paddle.optimizer.lr.PiecewiseDecay(boundaries=[], values=[]) + # check non-empty boundaries but empty values + with self.assertRaises(ValueError): + paddle.optimizer.lr.PiecewiseDecay(boundaries=[100, 200], values=[]) + # check boundaries and values has same length + with self.assertRaises(ValueError): + paddle.optimizer.lr.PiecewiseDecay( + boundaries=[100, 200], values=[0.5, 0.1] + ) func_api_kwargs = [ ( diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index bc5f9020b7f305..a6698cfb735eea 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -391,6 +391,14 @@ class PiecewiseDecay(LRScheduler): """ def __init__(self, boundaries, values, last_epoch=-1, verbose=False): + if len(boundaries) == 0: + raise ValueError('The boundaries cannot be empty.') + + if len(values) <= len(boundaries): + raise ValueError( + f'The values have one more element than boundaries, but received len(values) [{len(values)}] < len(boundaries) + 1 [{len(boundaries) + 1}].' + ) + self.boundaries = boundaries self.values = values super().__init__(last_epoch=last_epoch, verbose=verbose) diff --git a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py index 9e98ee484105fc..2333777a2cca61 100644 --- a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py +++ b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py @@ -88,7 +88,7 @@ def test_warn_or_error(self): optim = paddle.optimizer.Adam( learning_rate=paddle.optimizer.lr.PiecewiseDecay( - [0.001, 0.0001], [5, 10] + [0.001, 0.0001], [5, 10, 10] ), parameters=net.parameters(), )