Commit
fix get sub model (#733) (#746)
ceci3 authored May 18, 2021
1 parent 0642d9a commit a8173f3
Showing 6 changed files with 111 additions and 66 deletions.
97 changes: 71 additions & 26 deletions paddleslim/nas/ofa/get_sub_model.py
@@ -37,12 +37,15 @@ def get_prune_params_config(graph, origin_model_config):
### TODO(ceci3):
### 1. fix config when this op is concat by graph.pre_ops(op)
### 2. add kernel_size in config
-        ### 3. add channel in config
for inp in op.all_inputs():
n_ops = graph.next_ops(op)
if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
-                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                if 'expand_ratio' in origin_model_config[
+                        inp._var.name] or 'channel' in origin_model_config[
+                            inp._var.name]:
+                    key = 'channel' if 'channel' in origin_model_config[
+                        inp._var.name] else 'expand_ratio'
+                    tmp = origin_model_config[inp._var.name][key]
if len(inp._var.shape) > 1:
if inp._var.name in param_config.keys():
param_config[inp._var.name].append(tmp)
@@ -59,9 +62,13 @@ def get_prune_params_config(graph, origin_model_config):
if next_inp._var.persistable == True:
if next_inp._var.name in origin_model_config.keys():
-                        if 'expand_ratio' in origin_model_config[
-                                next_inp._var.name].keys():
-                            tmp = origin_model_config[next_inp._var.name][
-                                'expand_ratio']
+                        if 'expand_ratio' in origin_model_config[
+                                next_inp._var.
+                                name] or 'channel' in origin_model_config[
+                                    next_inp._var.name]:
+                            key = 'channel' if 'channel' in origin_model_config[
+                                next_inp._var.name] else 'expand_ratio'
+                            tmp = origin_model_config[next_inp._var.name][
+                                key]
pre = tmp if precedor is None else precedor
if len(next_inp._var.shape) > 1:
param_config[next_inp._var.name] = [pre]
@@ -78,9 +85,19 @@ def get_prune_params_config(graph, origin_model_config):
return param_config


def get_actual_shape(transform, channel):
if transform == None:
channel = int(channel)
else:
if isinstance(transform, float):
channel = int(channel * transform)
else:
channel = int(transform)
return channel


def prune_params(model, param_config, super_model_sd=None):
""" Prune parameters according to the config.
Parameters:
model(paddle.nn.Layer): instance of model.
param_config(dict): prune config of each weight.
@@ -104,25 +121,18 @@ def prune_params(model, param_config, super_model_sd=None):
in_exp = param_config[param.name][0]
out_exp = param_config[param.name][1]
if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = int(value.shape[1]) if in_exp == None else int(
-                            value.shape[1] * in_exp)
-                        out_chn = int(value.shape[
-                            0]) if out_exp == None else int(value.shape[0] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[1])
+                        out_chn = get_actual_shape(out_exp, value.shape[0])
prune_value = super_value[:out_chn, :in_chn, ...] \
if super_model_sd != None else value[:out_chn, :in_chn, ...]
else:
-                        in_chn = int(value.shape[0]) if in_exp == None else int(
-                            value.shape[0] * in_exp)
-                        out_chn = int(value.shape[
-                            1]) if out_exp == None else int(value.shape[1] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[0])
+                        out_chn = get_actual_shape(out_exp, value.shape[1])
prune_value = super_value[:in_chn, :out_chn, ...] \
if super_model_sd != None else value[:in_chn, :out_chn, ...]
else:
-                    out_chn = int(value.shape[0]) if param_config[param.name][
-                        0] == None else int(value.shape[0] *
-                                            param_config[param.name][0])
+                    out_chn = get_actual_shape(param_config[param.name][0],
+                                               value.shape[0])
prune_value = super_value[:out_chn, ...] \
if super_model_sd != None else value[:out_chn, ...]

@@ -140,23 +150,24 @@ def prune_params(model, param_config, super_model_sd=None):
if param.trainable:
param.clear_gradient()

    ### initialize params which are not in sublayers, such as persistable inputs created by create_parameters
if super_model_sd != None and len(super_model_sd) != 0:
for k, v in super_model_sd.items():
setattr(model, k, v)


def _is_depthwise(op):
"""Check if this op is depthwise conv.
"""Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
The shape of input and the shape of output in depthwise conv must be same in superlayer,
so depthwise op cannot be consider as weight op
"""
-    if op.type() == 'depthwise_conv':
-        return True
-    elif 'conv' in op.type():
+    #if op.type() == 'depthwise_conv2d': ### depthwise_conv2d in paddle only requires Cout % Cin == 0
+    #    return True
+    if 'conv' in op.type():
for inp in op.all_inputs():
-            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
-                    1]:
+            if inp._var.persistable and (
+                    op.attr('groups') == inp._var.shape[0] and
+                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
return True
return False

@@ -179,6 +190,7 @@ def _find_weight_ops(op, graph, weights):
weights.append(inp._var.name)
return weights
return _find_weight_ops(pre_op, graph, weights)
+    return weights


def _find_pre_elementwise_add(op, graph):
@@ -236,3 +248,36 @@ def check_search_space(graph):
depthwise_conv = sorted(depthwise_conv)

return (final_search_space, depthwise_conv)


def broadcast_search_space(same_search_space, param2key, origin_config):
"""
    Broadcast origin_config in place according to the shared search spaces. For example, given
    same_search_space = [['conv1_weight', 'conv3_weight']],
    param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'} and
    origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}},
    the result is origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}.
    Args:
        same_search_space(list<list>): broadcast according to this list; the layers in each inner list must keep consistent channel numbers.
        param2key(dict): mapping from parameter name to the corresponding layer key used in origin_config.
        origin_config(dict): the sub-model config to be broadcast in place.
"""
for per_ss in same_search_space:
for ss in per_ss[1:]:
key = param2key[ss]
pre_key = param2key[per_ss[0]]
if key in origin_config:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key].update({
'expand_ratio': origin_config[pre_key]['expand_ratio']
})
elif 'channel' in origin_config[pre_key]:
origin_config[key].update({
'channel': origin_config[pre_key]['channel']
})
else:
if 'expand_ratio' in origin_config[pre_key]:
origin_config[key] = {
'expand_ratio': origin_config[pre_key]['expand_ratio']
}
elif 'channel' in origin_config[pre_key]:
origin_config[key] = {
'channel': origin_config[pre_key]['channel']
}
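
For reference, a minimal sketch of how the two helpers added to get_sub_model.py are expected to behave. The shapes, parameter names, and configs below are made up for illustration; only the import path and the function signatures come from this diff.

from paddleslim.nas.ofa.get_sub_model import get_actual_shape, broadcast_search_space

# get_actual_shape: `transform` is either None (keep the original channel number),
# a float (treated as an expand ratio), or anything else (an explicit channel count).
assert get_actual_shape(None, 64) == 64   # no transform -> keep 64
assert get_actual_shape(0.5, 64) == 32    # float -> int(64 * 0.5)
assert get_actual_shape(48, 64) == 48     # int -> explicit channel count

# broadcast_search_space: layers listed in the same inner list must share one search
# space, so the first layer's setting is copied to the other layers in place.
same_ss = [['conv1_weight', 'conv3_weight']]                    # hypothetical parameter names
param2key = {'conv1_weight': 'conv1', 'conv3_weight': 'conv3'}  # hypothetical layer keys
config = {'conv1': {'channel': 10}, 'conv2': {'channel': 20}}
broadcast_search_space(same_ss, param2key, config)
print(config)
# {'conv1': {'channel': 10}, 'conv2': {'channel': 20}, 'conv3': {'channel': 10}}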
56 changes: 20 additions & 36 deletions paddleslim/nas/ofa/ofa.py
@@ -31,7 +31,7 @@
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
from .get_sub_model import get_prune_params_config, prune_params, check_search_space
from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space

_logger = get_logger(__name__, level=logging.INFO)

@@ -156,7 +156,6 @@ class OFA(OFABase):
sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
sp_model = Convert(sp_net_config).convert(model)
ofa_model = OFA(sp_model)
"""

def __init__(self,
@@ -461,6 +460,23 @@ def search(self, eval_func, condition):

def _export_sub_model_config(self, origin_model, config, input_shapes,
input_dtypes):
param2name = {}
for name, sublayer in origin_model.named_sublayers():
for param in sublayer.parameters(include_sublayers=False):
if name.split('.')[-1] == 'fn':
                    ### if the sublayer is a Block, its name ends with '.fn', but the keys in config never include '.fn'
param2name[param.name] = name[:-3]
else:
param2name[param.name] = name

program = dygraph2program(
origin_model, inputs=input_shapes, dtypes=input_dtypes)
graph = GraphWrapper(program)

same_config, _ = check_search_space(graph)
if same_config != None:
broadcast_search_space(same_config, param2name, config)

origin_model_config = {}
for name, sublayer in origin_model.named_sublayers():
if isinstance(sublayer, BaseBlock):
@@ -469,9 +485,6 @@ def _export_sub_model_config(self, origin_model, config, input_shapes,
if name in config.keys():
origin_model_config[param.name] = config[name]

-        program = dygraph2program(
-            origin_model, inputs=input_shapes, dtypes=input_dtypes)
-        graph = GraphWrapper(program)
param_prune_config = get_prune_params_config(graph, origin_model_config)
return param_prune_config

@@ -493,7 +506,6 @@ def export(self,
.. code-block:: python
from paddle.vision.models import mobilenet_v1
origin_model = mobilenet_v1()
config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
"""
@@ -505,7 +517,6 @@ def export(self,
origin_model = self.model
origin_model = origin_model._layers if isinstance(
origin_model, DataParallel) else origin_model

param_config = self._export_sub_model_config(origin_model, config,
input_shapes, input_dtypes)
prune_params(origin_model, param_config, super_sd)
@@ -602,7 +613,6 @@ def _clear_search_space(self, *inputs, **kwargs):
per_ss.append(key)
else:
_logger.info("{} not in ss".format(key))

if len(per_ss) != 0:
tmp_same_ss.append(per_ss)

@@ -626,33 +636,6 @@ def _clear_search_space(self, *inputs, **kwargs):
):
self._clear_width(name)

-    def _broadcast_ss(self):
-        """ broadcast search space after random sample."""
-        for per_ss in self._same_ss:
-            for ss in per_ss[1:]:
-                key = self._param2key[ss]
-                pre_key = self._param2key[per_ss[0]]
-                if key in self.current_config:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        })
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'channel': self.current_config[pre_key]['channel']
-                        })
-                else:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        }
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'channel': self.current_config[pre_key]['channel']
-                        }

def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
teacher_output = None
@@ -692,7 +675,8 @@ def forward(self, *inputs, **kwargs):
kwargs['depth'] = self.current_config['depth']

if self._broadcast:
-            self._broadcast_ss()
+            broadcast_search_space(self._same_ss, self._param2key,
+                                   self.current_config)

student_output = self.model.forward(*inputs, **kwargs)

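As a usage sketch of the updated export flow (closely following the TestExportCase1 test added below; the supernet model, data shape, and dtype here are illustrative, not part of the commit):

import numpy as np
import paddle
from paddleslim.nas.ofa import OFA

ofa_model = OFA(supernet_model)   # `supernet_model` is a converted supernet, as in the OFA class docstring
ofa_model.set_epoch(0)
data = paddle.to_tensor(np.random.random((3, 64)).astype(np.int64))
outs, _ = ofa_model(data)         # one forward pass samples current_config
config = ofa_model.current_config
# export() now builds a param-to-layer-name mapping, checks the search space on the
# traced graph, and broadcasts 'expand_ratio'/'channel' settings across layers that
# must stay consistent before pruning the sub-model parameters.
ofa_model.export(config, input_shapes=[[3, 64]], input_dtypes=['int64'])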
2 changes: 1 addition & 1 deletion paddleslim/nas/one_shot/one_shot_nas.py
@@ -34,7 +34,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
list<int>: The best tokens searched.
"""
super_net = None
-    for layer in model.sublayers(include_sublayers=False):
+    for layer in model.sublayers(include_self=True):
print("layer: {}".format(layer))
if isinstance(layer, OneShotSuperNet):
super_net = layer
5 changes: 2 additions & 3 deletions paddleslim/teachers/bert/model/transformer_encoder.py
@@ -37,14 +37,13 @@ def __init__(self, process_cmd, d_model, dropout_rate, name):

for cmd in self.process_cmd:
if cmd == "a": # add residual connection
-            self.functors.append(
-                lambda x, y: x + y if y is not None else x)
+            self.functors.append(lambda x, y: x + y if y is not None else x)
self.exec_order += "a"
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
"layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
+                            self.sublayers(include_self=True)),
LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
16 changes: 16 additions & 0 deletions tests/test_ofa.py
@@ -449,5 +449,21 @@ def test_export_model(self):
assert len(self.ofa_model.ofa_layers) == 38


class TestExportCase1(unittest.TestCase):
def setUp(self):
model = ModelLinear1()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = paddle.to_tensor(data_np)
self.ofa_model = OFA(model)
self.ofa_model.set_epoch(0)
outs, _ = self.ofa_model(self.data)
self.config = self.ofa_model.current_config

def test_export_model(self):
self.ofa_model.export(
self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
assert len(self.ofa_model.ofa_layers) == 4


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions tests/test_ofa_v2.py
@@ -77,6 +77,7 @@ def forward(self, x):
y = x + y
z = self.branch2(y)
z = z + y
z = self.out(z)
return z


