From 3d0755b187adfd00ece82a50f753c36706c595c5 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Thu, 18 Aug 2022 10:23:10 +0800 Subject: [PATCH] Add Early Stop in AutoCompression (#1358) --- paddleslim/auto_compression/compressor.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index dc825d08b1a5d..8fbfacc463804 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -714,7 +714,10 @@ def _start_train(self, train_program_info, test_program_info, strategy, best_metric = -1.0 total_epochs = train_config.epochs if train_config.epochs else 100 total_train_iter = 0 + stop_training = False for epoch_id in range(total_epochs): + if stop_training: + break for batch_id, data in enumerate(self.train_dataloader()): np_probs_float, = self._exe.run(train_program_info.program, \ feed=data, \ @@ -760,6 +763,10 @@ def _start_train(self, train_program_info, test_program_info, strategy, abs(best_metric - self.metric_before_compressed) ) / self.metric_before_compressed <= 0.005: + _logger.info( + "The error rate between the compressed model and original model is less than 5%. The training process ends." + ) + stop_training = True break else: _logger.info( @@ -767,14 +774,18 @@ def _start_train(self, train_program_info, test_program_info, strategy, format(epoch_id, metric, best_metric)) if train_config.target_metric is not None: if metric > float(train_config.target_metric): + stop_training = True + _logger.info( + "The metric of compressed model has reached the target metric. The training process ends." + ) break else: _logger.warning( "Not set eval function, so unable to test accuracy performance." ) - if train_config.train_iter and total_train_iter >= train_config.train_iter: - epoch_id = total_epochs + if (train_config.train_iter and total_train_iter >= + train_config.train_iter) or stop_training: break if 'unstructure' in self._strategy or train_config.sparse_model: