This repository has been archived by the owner on Mar 31, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsyncbn.py
332 lines (282 loc) · 11.9 KB
/
syncbn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import threading

import numpy as np
import mxnet as mx
from mxnet import autograd, test_utils
from mxnet.ndarray import NDArray
from mxnet.gluon import HybridBlock
# Public names exported by a star-import of this module.
__all__ = ['ModelDataParallel', 'BatchNorm']
class ModelDataParallel(object):
    """Data parallelism wrapper around a Gluon block.

    Inputs and outputs are both lists of NDArrays living on different
    contexts. At construction time the module's parameters are reset onto
    every context in ``ctx``; each call then runs the module once per input
    slice, and each replica handles its own portion of the input. During the
    backwards pass, gradients from each replica are summed into the original
    module.

    Parameters
    ----------
    module : object
        Network to be parallelized.
    ctx : list
        A list of contexts to use.
    sync : bool, default True
        If True, every input slice is processed in its own thread via
        ``_parallel_apply``. This is needed when the network contains the
        synchronized ``BatchNorm`` of this module, whose cross-device
        reduction waits for all replicas at once. If False, the slices are
        processed one after another in the calling thread.

    Inputs:
        - **inputs**: list of input (NDArrays), or a single NDArray

    Outputs:
        - **outputs**: list of output (NDArrays); a 1-tuple when exactly one
          input slice is given; the bare module output for a single NDArray

    Example::
        >>> ctx = [mx.gpu(0), mx.gpu(1)]
        >>> net = ModelDataParallel(model, ctx=ctx)
        >>> x = gluon.utils.split_and_load(data, ctx_list=ctx)
        >>> y = net(x)
    """

    def __init__(self, module, ctx, sync=True):
        self.ctx = ctx
        # Place (copies of) every parameter on each target context.
        module.collect_params().reset_ctx(ctx=ctx)
        self.module = module
        self.sync = sync

    def __call__(self, inputs):
        if self.sync:
            return _parallel_apply(self.module, inputs)
        # Sequential fallback: run every slice in the calling thread.
        net = self.module
        if isinstance(inputs, NDArray):
            return net(inputs)
        if len(inputs) == 1:
            return (net(inputs[0]),)
        return [net(x) for x in inputs]
class BatchNorm(HybridBlock):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalizes the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .

    During training every replica thread contributes per-device sums of ``x``
    and ``x^2``. These are all-reduced across the ``nGPUs`` replicas through
    ``SharedTensor``, and the global mean/std are used to normalize. The
    running statistics are updated once per step through ``SharedUpdater``.
    Outside training the stock ``BatchNorm`` operator is used with the
    running statistics.

    The forward pass reads ``x.shape`` and ``x.context``. These are NDArray
    attributes, so this block presumably cannot be hybridized (TODO confirm).

    Parameters
    ----------
    momentum: float, default 0.9
        Momentum for the moving average:
        ``running = momentum * running + (1 - momentum) * batch_stat``.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`),
        this can be disabled since the scaling
        will be done by the next layer.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.
    running_mean_initializer: str or `Initializer`, default 'zeros'
        Initializer for the running mean.
    running_variance_initializer: str or `Initializer`, default 'ones'
        Initializer for the running variance.
    in_channels : int, default 0
        Number of channels (feature maps) in input data. If not specified,
        initialization will be deferred to the first time `forward` is called
        and `in_channels` will be inferred from the shape of input data.
    nGPUs : int, default number of visible GPUs
        Number of replicas that take part in each synchronized step. It must
        equal the number of threads calling this layer concurrently (e.g. the
        number of contexts passed to `ModelDataParallel`). Otherwise
        `SharedTensor.push` keeps waiting for missing replicas.

    Inputs:
        - **data**: 4-D input tensor. The per-replica element count is
          ``x.shape[0] * x.shape[2] * x.shape[3]``, so the channel axis is
          presumably axis 1 (NCHW).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
    """

    def __init__(self, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 running_mean_initializer='zeros', running_variance_initializer='ones',
                 in_channels=0, nGPUs=None, **kwargs):
        super(BatchNorm, self).__init__(**kwargs)
        # Keyword arguments forwarded to both the training and inference ops.
        self._kwargs = {'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale}
        if in_channels != 0:
            self.in_channels = in_channels
        self.eps = epsilon
        self.momentum = momentum
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(in_channels,), init=gamma_initializer,
                                     allow_deferred_init=True,
                                     differentiable=scale)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(in_channels,), init=beta_initializer,
                                    allow_deferred_init=True,
                                    differentiable=center)
        self.running_mean = self.params.get('running_mean', grad_req='null',
                                            shape=(in_channels,),
                                            init=running_mean_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
        self.running_var = self.params.get('running_var', grad_req='null',
                                           shape=(in_channels,),
                                           init=running_variance_initializer,
                                           allow_deferred_init=True,
                                           differentiable=False)
        if nGPUs is None:
            nGPUs = self._get_nGPUs()
        # Cross-replica rendezvous for sum(x) and sum(x^2).
        self.xsum = SharedTensor(nGPUs)
        self.xsqu = SharedTensor(nGPUs)
        # Applies the running-statistics update once per synchronized step.
        self.updater = SharedUpdater(nGPUs)

    def _get_nGPUs(self):
        """Return the number of visible GPUs, or 1 when none are found (CPU)."""
        # caution: if not using all the GPUs, please manually set nGPUs
        nGPUs = len(test_utils.list_gpus())
        # for CPU
        nGPUs = nGPUs if nGPUs > 0 else 1
        return nGPUs

    def cast(self, dtype):
        """Cast parameters to ``dtype``, promoting float16 to float32.

        Presumably the statistics/custom ops need full precision
        (TODO confirm).
        """
        if np.dtype(dtype).name == 'float16':
            dtype = 'float32'
        super(BatchNorm, self).cast(dtype)

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        if autograd.is_training():
            # Per-device sum(x) and sum(x^2) from the custom SumSquare op
            # (presumably reduced per channel -- confirm against the op).
            isum, isqu = F.SumSquare(x)
            # reduce sum for E(x) and E(x^2); both calls block until every
            # replica has contributed.
            idsum = self.xsum.push(isum)
            idsqu = self.xsqu.push(isqu)
            osum = self.xsum.get(F, idsum)
            osqu = self.xsqu.get(F, idsqu)
            assert(len(self.xsum) == len(self.xsqu))
            # Total number of elements per channel over all replicas.
            N = len(self.xsum) * x.shape[0] * x.shape[2] * x.shape[3]
            # calc mean and std
            mean = osum / N
            sumvar = osqu - osum * osum / N
            std = F.sqrt(sumvar / N + self.eps)
            # update running mean and var (outside the autograd graph)
            with autograd.pause():
                unbias_var = sumvar / (N - 1)
                ctx = x.context
                self.updater(self.running_mean, self.running_var, mean, unbias_var,
                             self.momentum, ctx)
            return F.DecoupleBatchNorm(x, gamma, beta, mean, std,
                                       name='fwd', **self._kwargs)
        # Inference: normalize with the running statistics.
        return F.BatchNorm(x, gamma, beta, running_mean, running_var, name='fwd',
                           **self._kwargs)

    def __repr__(self):
        s = '{name}({content}'
        in_channels = self.gamma.shape[0]
        s += ', in_channels={0}'.format(in_channels if in_channels else None)
        s += ')'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))
class SharedUpdater:
    """Update running statistics exactly once per round of ``nGPUs`` calls.

    Every replica thread calls the instance once per step. Only the first
    caller of a round blends the batch statistics into the running ones. The
    round counter resets after the ``nGPUs``-th call.
    """

    def __init__(self, nGPUs):
        self.mutex = threading.Lock()
        self.nGPUs = nGPUs
        self._clear()

    def _clear(self):
        # Start a new round: nGPUs calls remain.
        self.tasks = self.nGPUs

    def __call__(self, running_mean, running_var, mean, unbias_var, momentum, ctx):
        with self.mutex:
            starts_round = self.tasks == self.nGPUs
            if starts_round:
                old_mean = running_mean.data(ctx)
                running_mean.set_data(momentum * old_mean + (1.0 - momentum) * mean)
                old_var = running_var.data(ctx)
                running_var.set_data(momentum * old_var + (1.0 - momentum) * unbias_var)
            self.tasks -= 1
            if not self.tasks:
                self._clear()
class SharedTensor:
    """Rendezvous that all-reduces one tensor contributed by each of nGPUs threads.

    Protocol per round:

    1. Each participating thread calls ``push`` with its local tensor.
    2. Each thread then calls ``get`` with the index that ``push`` returned.

    ``push`` blocks until all ``nGPUs`` tensors have been pushed. ``get``
    blocks until every thread has entered the reduction. The last thread to
    enter replaces the collected list with the result of ``F.AllReduce``.
    State is reset lazily by the first ``push`` of the next round.
    """
    def __init__(self, nGPUs):
        # A single lock guards all state. The condition shares it, so the
        # barrier waits below release it while sleeping.
        self.mutex = threading.Lock()
        self.all_tasks_done = threading.Condition(self.mutex)
        self.nGPUs = nGPUs
        self._clear()

    def _clear(self):
        """Reset the collected tensors and both per-round countdowns."""
        self.list = []
        self.push_tasks = self.nGPUs
        self.reduce_tasks = self.nGPUs

    def push(self, t):
        """Add ``t`` to the current round; block until all nGPUs threads pushed.

        Returns the position of ``t`` in the round, to be passed to ``get``.
        """
        with self.mutex:
            # push countdown at 0 means the previous round finished pushing;
            # the first push of the new round resets the state.
            if self.push_tasks == 0:
                self._clear()
            self.list.append(t)
            idx = len(self.list) - 1
            self.push_tasks -= 1
        # Barrier: the last pusher wakes everyone, the others wait for it.
        with self.all_tasks_done:
            if self.push_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.push_tasks:
                self.all_tasks_done.wait()
        return idx

    def _reduce(self, F):
        """Barrier in which the last arriving thread all-reduces the tensors."""
        with self.mutex:
            if self.reduce_tasks == 1:
                # Last thread to arrive: every push has completed (push blocks
                # until then), so the list is full.
                assert(len(self.list) == self.nGPUs)
                # NOTE(review): assumes F.AllReduce returns one reduced tensor
                # per input, in input order -- confirm against the custom op.
                self.list = F.AllReduce(*self.list)
                for xi in self.list:
                    # manually attach grad to avoid wrong allocation
                    xi.attach_grad()
                    xi.wait_to_read()
                self.reduce_tasks -= 1
            else:
                self.reduce_tasks -= 1
        # Barrier: nobody leaves until the reduced list is in place.
        with self.all_tasks_done:
            if self.reduce_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.reduce_tasks:
                self.all_tasks_done.wait()

    def get(self, F, idx):
        """Return the reduced tensor at position ``idx`` after the barrier.

        NOTE(review): the read of ``self.list`` happens after the barrier.
        A new round's first ``push`` would replace the list, so this
        presumably relies on callers not starting the next round until every
        thread has returned (``_parallel_apply`` joins all threads) -- confirm.
        """
        self._reduce(F)
        return self.list[idx]

    def test(self):
        """Debug helper: print the currently held tensors."""
        print('self.list', self.list)

    def __len__(self):
        # Number of tensors currently held (pushed ones, or reduced ones
        # after ``get``).
        return len(self.list)

    def __repr__(self):
        return 'SharedTensor'
def _wait_to_read(output):
    """Block until ``output`` (an NDArray or a nested list/tuple of them) is ready."""
    if isinstance(output, (list, tuple)):
        for item in output:
            _wait_to_read(item)
    else:
        output.wait_to_read()


def _parallel_apply(module, inputs, kwargs_tup=None):
    """Run ``module`` on every input slice in its own thread.

    Parameters
    ----------
    module : callable
        Network applied to each slice.
    inputs : NDArray or list of NDArray
        A single NDArray is passed straight to ``module``. A one-element list
        is run in the calling thread and returned as a 1-tuple. Otherwise one
        thread is started per element.
    kwargs_tup : sequence of dict, optional
        Per-slice keyword arguments; must match ``len(inputs)`` if given.

    Returns
    -------
    list
        Outputs in input order. The first exception raised by any worker is
        re-raised in the caller after all threads have joined.
    """
    if kwargs_tup:
        assert len(inputs) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(inputs)
    if isinstance(inputs, NDArray):
        return module(inputs, **kwargs_tup[0])
    if len(inputs) == 1:
        return (module(inputs[0], **kwargs_tup[0]), )

    # The autograd recording/training flags do not carry over into new
    # threads (which is why workers must re-enter a scope). Capture the
    # caller's exact state and reproduce it in every worker. Previously only
    # is_training() was checked, so e.g. record(train_mode=False) lost gradients.
    is_recording = autograd.is_recording()
    is_training = autograd.is_training()

    def _scope():
        if is_recording:
            return autograd.record(train_mode=is_training)
        if is_training:
            return autograd.train_mode()
        return autograd.predict_mode()

    lock = threading.Lock()
    results = {}

    def _worker(i, module, x, kwargs, results, lock):
        try:
            with _scope():
                output = module(x, **kwargs)
                # Handles modules that return several outputs as a tuple/list.
                _wait_to_read(output)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    threads = [threading.Thread(target=_worker,
                                args=(i, module, x, kwargs, results, lock))
               for i, (x, kwargs) in enumerate(zip(inputs, kwargs_tup))]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs