diff --git a/slim/prune/export_model.py b/slim/prune/export_model.py index dd3c0ae463ba5..d8427d79a7175 100644 --- a/slim/prune/export_model.py +++ b/slim/prune/export_model.py @@ -63,6 +63,9 @@ def main(): test_fetches = model.test(feed_vars) infer_prog = infer_prog.clone(True) + exe.run(startup_prog) + checkpoint.load_checkpoint(exe, infer_prog, cfg.weights) + pruned_params = FLAGS.pruned_params assert ( FLAGS.pruned_params is not None @@ -90,13 +93,9 @@ def main(): logger.info("pruned FLOPS: {}".format( float(base_flops - pruned_flops) / base_flops)) - exe.run(startup_prog) - checkpoint.load_checkpoint(exe, infer_prog, cfg.weights) - dump_infer_config(FLAGS, cfg) save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) - if __name__ == '__main__': enable_static_mode() parser = ArgsParser()