Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#11 from wanghaoshuang/flops_pruning
Browse files Browse the repository at this point in the history
Add flops pruning
  • Loading branch information
wanghaoshuang authored Feb 25, 2019
2 parents fd8799b + 04887c0 commit 8c19599
Show file tree
Hide file tree
Showing 3 changed files with 682 additions and 227 deletions.
148 changes: 112 additions & 36 deletions python/paddle/fluid/contrib/slim/core/compress_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from .... import io
from .... import profiler
from ....data_feeder import DataFeeder
from ..graph import get_executor, ImitationGraph
from ..graph import *
from config import ConfigFactory
import numpy as np
from collections import Iterable
import time
import os
import logging
import sys
import pickle

__all__ = ['Context', 'CompressPass']

Expand All @@ -32,6 +33,30 @@
logger = logging.getLogger(__name__)


def cached_reader(reader, sampled_rate, cache_path, cached_id):
np.random.seed(cached_id)
cache_path = cache_path + "/" + str(cached_id)
logger.info('read data from: {}'.format(cache_path))

def s_reader():
if os.path.isdir(cache_path):
for file_name in open(cache_path + "/list"):
yield np.load(cache_path + '/' + file_name.strip())
else:
os.makedirs(cache_path)
list_file = open(cache_path + "/list", 'w')
batch = 0
dtype = None
for data in reader():
if batch == 0 or (np.random.uniform() < sampled_rate):
np.save(cache_path + '/batch' + str(batch), data)
list_file.write('batch' + str(batch) + '.npy\n')
batch += 1
yield data

return s_reader


class Context(object):
"""
The context in the process of compression.
Expand All @@ -55,7 +80,8 @@ def __init__(self,
# The total number of epoches to be trained.
self.epoch = 0
# Current epoch
self.epoch_id = 0
# self.epoch_id = -1
self.epoch_id = -1
# Current batch
self.batch_id = 0

Expand All @@ -71,8 +97,31 @@ def __init__(self,
self.train_optimizer = train_optimizer
self.distiller_optimizer = distiller_optimizer
self.optimize_graph = None

def run_eval_graph(self):
self.cache_path = './eval_cache'
self.eval_results = {}

def to_file(self, file_name):
data = {}
data['epoch_id'] = self.epoch_id
data['eval_results'] = self.eval_results
with open(file_name, 'wb') as context_file:
pickle.dump(data, context_file)

def from_file(self, file_name):
with open(file_name) as context_file:
data = pickle.load(context_file)
self.epoch_id = data['epoch_id']
self.eval_results = data['eval_results']

def eval_converged(self, metric_name, delta=0.001):
if (metric_name not in self.eval_results
) or len(self.eval_results[metric_name]) < 2:
return False
results = self.eval_results[metric_name][-2:]
logger.info('Latest evaluations: {}'.format(results))
return abs(results[1] - results[0]) / results[0] < delta

def run_eval_graph(self, sampled_rate=None, cached_id=0):
logger.info(
'--------------------------Running evaluation----------------------')
assert self.eval_graph is not None
Expand All @@ -81,7 +130,12 @@ def run_eval_graph(self):
executor = get_executor(eval_graph, self.place, parallel=True)
results = []
batch_id = 0
for data in self.eval_reader():
s_time = time.time()
reader = self.eval_reader
if sampled_rate:
reader = cached_reader(reader, sampled_rate, self.cache_path,
cached_id)
for data in reader():
result = executor.run(eval_graph, data=data)
result = [np.mean(r) for r in result]
results.append(result)
Expand Down Expand Up @@ -131,6 +185,7 @@ def __init__(self,
eval_feed_list=None,
eval_fetch_list=None,
teacher_programs=[],
checkpoint_path='./checkpoints',
train_optimizer=None,
distiller_optimizer=None):
self.strategies = []
Expand All @@ -153,7 +208,7 @@ def __init__(self,
self.teacher_graphs.append(ImitationGraph(teacher, scope=scope))

self.checkpoint = None
self.model_save_dir = './checkpoints/'
self.checkpoint_path = checkpoint_path
self.eval_epoch = 1

self.train_optimizer = train_optimizer
Expand All @@ -177,30 +232,55 @@ def config(self, config_file):
self.add_strategy(strategy)
if 'init_epoch' in factory.compress_pass:
self.init_epoch = factory.compress_pass['init_epoch']
if 'model_save_dir' in factory.compress_pass:
self.model_save_dir = factory.compress_pass['model_save_dir']
if 'checkpoint_path' in factory.compress_pass:
self.checkpoint_path = factory.compress_pass['checkpoint_path']

def _load_checkpoint(self, context):
if self.checkpoint:
exe = get_executor(
context.train_graph, context.place, parallel=False)
io.load_persistables(
exe.exe,
self.checkpoint,
main_program=context.train_graph.program)
print("Loaded checkpoint from: {}".format(self.checkpoint))
logger.info('_load_checkpoint')
strategies = self.strategies
if self.checkpoint_path:
checkpoints = [
dir for dir in os.listdir(self.checkpoint_path)
if os.path.isdir(os.path.join(self.checkpoint_path, dir))
]
logger.info('self.checkpoint_path: {}'.format(self.checkpoint_path))
logger.info('checkpoints: {}'.format(checkpoints))
if len(checkpoints) > 0:
latest = max(checkpoints)
latest_ck_path = os.path.join(self.checkpoint_path, str(latest))

model_path = os.path.join(latest_ck_path, 'model')
context_path = os.path.join(latest_ck_path, 'context')
strategy_path = os.path.join(latest_ck_path, 'strategies')
context.from_file(context_path)
with open(strategy_path, 'rb') as strategy_file:
strategies = pickle.load(strategy_file)

exe = get_executor(
context.train_graph, context.place, parallel=False)
load_persistables(context.train_graph, model_path, exe)
update_param_shape(context.eval_graph)
update_depthwise_conv(context.eval_graph)
logger.info("Loaded checkpoint from: {}".format(
self.checkpoint_path))
return context, strategies

def _save_checkpoint(self, context):
if context.epoch_id % 1 == 0 and self.model_save_dir:
model_path = os.path.join(
self.model_save_dir,
str(context.epoch_id) + "_" + str(context.batch_id))
if context.epoch_id % 1 == 0 and self.checkpoint_path:
checkpoint_path = os.path.join(self.checkpoint_path,
str(context.epoch_id))
model_path = os.path.join(checkpoint_path, 'model')
context_path = os.path.join(checkpoint_path, 'context')
strategy_path = os.path.join(checkpoint_path, 'strategies')
if not os.path.isdir(model_path):
os.makedirs(model_path)
exe = get_executor(context.train_graph, context.place, False)
io.save_persistables(
exe.exe, model_path, main_program=context.train_graph.program)
logger.info('Saved checkpoint to: {}'.format(model_path))
exe = get_executor(
context.train_graph, context.place, parallel=False)
save_persistables(context.train_graph, model_path, exe)
context.to_file(context_path)
with open(strategy_path, 'wb') as strategy_file:
pickle.dump(self.strategies, strategy_file)
logger.info('Saved checkpoint to: {}'.format(checkpoint_path))

def _train_one_epoch(self, context):
if context.train_graph is None:
Expand All @@ -223,7 +303,6 @@ def _train_one_epoch(self, context):
executor = get_executor(
context.optimize_graph, self.place, parallel=True)

# with profiler.profiler('GPU', 'total'):
for data in context.train_reader():
for strategy in self.strategies:
strategy.on_batch_begin(context)
Expand All @@ -239,13 +318,16 @@ def _train_one_epoch(self, context):
strategy.on_batch_end(context)
context.batch_id += 1
context.batch_id = 0
self._save_checkpoint(context)
logger.info(
'-----------------------Finish training one epoch-----------------------'
)

def _eval(self, context):
results, names = context.run_eval_graph()
for name, result in zip(names, results):
if name not in context.eval_results:
context.eval_results[name] = []
context.eval_results[name].append(result)

def run(self):

Expand All @@ -259,32 +341,26 @@ def run(self):
train_optimizer=self.train_optimizer,
distiller_optimizer=self.distiller_optimizer)

self._load_checkpoint(context)
context, self.strategies = self._load_checkpoint(context)

self.executor = get_executor(
self.train_graph, self.place, parallel=True)
context.put('executor', self.executor)

if self.teacher_graphs:
context.put('teachers', self.teacher_graphs)

for strategy in self.strategies:
strategy.on_compression_begin(context)

for epoch in range(self.init_epoch, self.epoch):
start = context.epoch_id + 1
for epoch in range(start, self.epoch):
context.epoch_id = epoch
print('context.epoch_id: {}'.format(context.epoch_id))
for strategy in self.strategies:
strategy.on_epoch_begin(context)

self._train_one_epoch(context)

for strategy in self.strategies:
strategy.on_epoch_end(context)

if self.eval_epoch and epoch % self.eval_epoch == 0:
self._eval(context)
self._save_checkpoint(context)
for strategy in self.strategies:
strategy.on_compression_end(context)

return context.eval_graph
Loading

0 comments on commit 8c19599

Please sign in to comment.