-
Notifications
You must be signed in to change notification settings - Fork 349
/
Copy pathimportance.py
150 lines (128 loc) · 5.54 KB
/
importance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
def compute_neuron_head_importance(args, model, dev_ds, place, model_cfg):
    """Estimate the importance of every attention head and FFN neuron.

    Runs ``model`` over ``dev_ds`` with a differentiable all-ones head mask,
    back-propagates the loss of each batch and accumulates, over all batches:

    * head importance: ``abs(grad(head_mask))``;
    * neuron importance, per FFN layer:
      ``abs(sum_axis0(W1 * grad(W1)) + b1 * grad(b1))``
      plus ``abs(sum_axis1(W2 * grad(W2)))``, where ``W1``/``b1`` are the
      ``ffn.i`` weight/bias and ``W2`` is the ``ffn.o`` weight.

    Args:
        args: run configuration; only ``args.task`` is read.
        model: model whose parameters with ``'ffn.i'`` / ``'ffn.o'`` in their
            names are the FFN projections. Its forward is called with
            ``labels``, ``head_mask`` and ``num_layers`` keyword arguments and
            its first output is used as the loss.
        dev_ds: dataset whose ``start(place)`` yields ``(ids, sids, label)``
            batches.
        place: device place forwarded to ``dev_ds.start``.
        model_cfg: mapping providing ``'num_hidden_layers'`` and
            ``'num_attention_heads'``.

    Returns:
        tuple: ``(head_importance, neuron_importance)``.
        ``head_importance`` is a float32 paddle tensor of shape
        ``[num_hidden_layers, num_attention_heads]``.
        ``neuron_importance`` is a list with one float32 numpy vector per
        ``ffn.i`` weight, of length ``w.shape[1]`` of that weight.
    """
    n_layers = model_cfg['num_hidden_layers']
    n_heads = model_cfg['num_attention_heads']

    head_importance = paddle.zeros(shape=[n_layers, n_heads], dtype='float32')
    # Differentiable mask whose gradient scores each head. Presumably applied
    # multiplicatively inside the model, so all-ones leaves the forward pass
    # unchanged -- TODO confirm against the model implementation.
    head_mask = paddle.ones(shape=[n_layers, n_heads], dtype='float32')
    head_mask.stop_gradient = False

    # Collect FFN parameters: multi-dimensional 'ffn.i' params are weights,
    # 1-D ones are biases; only the 'ffn.o' weights are needed.
    intermediate_weight = []
    intermediate_bias = []
    output_weight = []
    for name, w in model.named_parameters():
        if 'ffn.i' in name:
            if len(w.shape) > 1:
                intermediate_weight.append(w)
            else:
                intermediate_bias.append(w)
        if 'ffn.o' in name and len(w.shape) > 1:
            output_weight.append(w)

    neuron_importance = [
        np.zeros(shape=[w.shape[1]], dtype='float32')
        for w in intermediate_weight
    ]

    eval_task_names = ('mnli', 'mnli-mm') if args.task == 'mnli' else (
        args.task, )
    # NOTE(review): the task name is not used inside the loop and the same
    # dev_ds is traversed for every entry, so for MNLI the dev set is simply
    # processed twice. Kept to preserve the original scores' behaviour.
    for _eval_task in eval_task_names:
        for batch in dev_ds.start(place):
            ids, sids, label = batch
            out = model(
                ids,
                sids,
                labels=label,
                head_mask=head_mask,
                num_layers=n_layers)
            loss = out[0]
            loss.backward()

            head_importance += paddle.abs(
                paddle.to_tensor(head_mask.gradient()))
            for w1, b1, w2, current_importance in zip(
                    intermediate_weight, intermediate_bias, output_weight,
                    neuron_importance):
                # In-place numpy updates mutate the arrays held by
                # neuron_importance.
                current_importance += np.abs(
                    np.sum(w1.numpy() * w1.gradient(), axis=0) +
                    b1.numpy() * b1.gradient())
                current_importance += np.abs(
                    np.sum(w2.numpy() * w2.gradient(), axis=1))

            # Paddle accumulates gradients across backward() calls. Clear
            # them so the next batch is scored with its own gradient rather
            # than the running sum of all previous batches, and so no stale
            # gradients are left on the model or the mask afterwards.
            head_mask.clear_gradient()
            model.clear_gradients()
    return head_importance, neuron_importance
def _argsort_descending(x):
    """Return the index tensor that sorts ``x`` in descending order.

    ``paddle.argsort`` returns only the indices in Paddle 2.x, whereas the
    legacy ``fluid.layers.argsort`` returned ``(sorted_values, indices)``.
    Both forms are accepted so that the full permutation is always returned.
    """
    result = paddle.argsort(x, descending=True)
    if isinstance(result, (tuple, list)):
        result = result[-1]
    return result


def reorder_neuron_head(model, head_importance, neuron_importance):
    """Reorder each layer's attention heads and FFN neurons by importance.

    After the call, the heads and FFN neurons of every layer are permuted
    in place so that they appear in descending order of importance.

    Args:
        model: model exposing ``model.encoder_stack.block[i].attn`` and
            ``.ffn`` sub-layers.
        head_importance: tensor indexed by layer; row ``i`` holds the
            per-head scores of layer ``i``.
        neuron_importance: list with one numpy score vector per layer. Its
            length determines how many layers are reordered.
    """
    blocks = model.encoder_stack.block
    for layer, current_importance in enumerate(neuron_importance):
        # Full descending permutation of the heads, not a single index.
        head_order = _argsort_descending(head_importance[layer])
        reorder_head(blocks[layer].attn, head_order)

        neuron_order = _argsort_descending(
            paddle.to_tensor(current_importance))
        reorder_neuron(blocks[layer].ffn, neuron_order)
def reorder_head(layer, idx):
    """Permute the attention heads of ``layer`` in place.

    Args:
        layer: attention sub-layer exposing ``n_head``, ``d_key`` and the
            projections ``q``, ``k``, ``v`` and ``o``. Each projection is
            either a linear layer or a wrapper holding one in ``.fn``.
        idx: 1-D int64 tensor with the new head order; ``idx[i]`` is the
            original index of the head placed at position ``i``.
    """
    n, a = layer.n_head, layer.d_key
    # Expand the head permutation into a permutation of the n * d_key
    # projection features: head h owns features [h * a, (h + 1) * a).
    index = paddle.reshape(
        paddle.index_select(
            paddle.reshape(
                paddle.arange(
                    0, n * a, dtype='int64'), shape=[n, a]),
            idx,
            axis=0),
        shape=[-1])

    def reorder_head_matrix(linearLayer, index, dim=1):
        """Permute ``linearLayer.weight`` along ``dim`` in place.

        The bias is permuted too when ``dim != 0``. With ``dim == 0`` it is
        copied unchanged. This assumes Paddle's ``[in, out]`` Linear weight
        layout, so axis 0 indexes input features, which the bias does not
        depend on.
        """
        W = paddle.index_select(linearLayer.weight, index, axis=dim).detach()
        if linearLayer.bias is not None:
            if dim == 0:
                b = paddle.assign(linearLayer.bias).detach()
            else:
                # Previously `L.index_select(..., dim=0)`: `L` was never
                # imported, so this branch raised NameError.
                b = paddle.assign(
                    paddle.index_select(
                        linearLayer.bias, index, axis=0)).detach()
        # Toggle stop_gradient around set_value; presumably so the in-place
        # overwrite is not tracked by autograd.
        linearLayer.weight.stop_gradient = True
        linearLayer.weight.set_value(W)
        linearLayer.weight.stop_gradient = False
        if linearLayer.bias is not None:
            linearLayer.bias.stop_gradient = True
            linearLayer.bias.set_value(b)
            linearLayer.bias.stop_gradient = False

    # q/k/v: permute output features (axis 1); o: permute input rows (axis 0).
    reorder_head_matrix(
        layer.q.fn if hasattr(layer.q, 'fn') else layer.q, index)
    reorder_head_matrix(
        layer.k.fn if hasattr(layer.k, 'fn') else layer.k, index)
    reorder_head_matrix(
        layer.v.fn if hasattr(layer.v, 'fn') else layer.v, index)
    reorder_head_matrix(
        layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)
def reorder_neuron(layer, index, dim=0):
    """Permute the intermediate neurons of a feed-forward sub-layer in place.

    Args:
        layer: FFN sub-layer exposing projections ``i`` and ``o``. Each is
            either a linear layer or a wrapper holding one in ``.fn``.
        index: 1-D int64 tensor with the new neuron order.
        dim: unused; kept only for backward compatibility of the signature.
    """

    def reorder_neurons_matrix(linearLayer, index, dim):
        """Permute ``linearLayer.weight`` along ``dim`` in place.

        The bias is permuted too when ``dim != 0``. With ``dim == 0`` it is
        copied unchanged. This assumes Paddle's ``[in, out]`` Linear weight
        layout, so axis 0 indexes input features, which the bias does not
        depend on.
        """
        W = paddle.index_select(linearLayer.weight, index, axis=dim).detach()
        if linearLayer.bias is not None:
            if dim == 0:
                b = paddle.assign(linearLayer.bias).detach()
            else:
                # Previously `L.index_select(..., dim=0)`: `L` was never
                # imported, so this branch raised NameError.
                b = paddle.assign(
                    paddle.index_select(
                        linearLayer.bias, index, axis=0)).detach()
        # Toggle stop_gradient around set_value; presumably so the in-place
        # overwrite is not tracked by autograd.
        linearLayer.weight.stop_gradient = True
        linearLayer.weight.set_value(W)
        linearLayer.weight.stop_gradient = False
        if linearLayer.bias is not None:
            linearLayer.bias.stop_gradient = True
            linearLayer.bias.set_value(b)
            linearLayer.bias.stop_gradient = False

    # i: permute its output features (axis 1) together with its bias;
    # o: permute the matching input rows (axis 0), bias unchanged.
    reorder_neurons_matrix(
        layer.i.fn if hasattr(layer.i, 'fn') else layer.i, index, dim=1)
    reorder_neurons_matrix(
        layer.o.fn if hasattr(layer.o, 'fn') else layer.o, index, dim=0)