forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcommon.py
373 lines (306 loc) · 13.6 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import sys, os, logging, functools
import multiprocessing as mp
import mxnet as mx
import numpy as np
import random
import shutil
from mxnet.base import MXNetError
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../common/'))
sys.path.insert(0, os.path.join(curr_path, '../../../python'))
import models
from contextlib import contextmanager
import pytest
from tempfile import TemporaryDirectory
import locale
# Reusable pytest marker: a test decorated with it is expected to fail (xfail)
# whenever the decimal separator reported by locale.localeconv() is not ".".
# The condition is computed once, when this module-level statement runs.
xfail_when_nonstandard_decimal_separator = pytest.mark.xfail(
locale.localeconv()["decimal_point"] != ".",
reason="Some operators break when the decimal separator is set to anything other than \".\". "
"These operators should be rewritten to utilize the new FFI. Please see #18097 for more "
"information."
)
def assertRaises(expected_exception, func, *args, **kwargs):
    """Assert that ``func(*args, **kwargs)`` raises ``expected_exception``.

    Parameters
    ----------
    expected_exception : exception class, or tuple of exception classes,
        accepted exactly as by an ``except`` clause.
    func : callable to invoke.
    *args, **kwargs : forwarded unchanged to ``func``.

    Raises
    ------
    AssertionError
        If ``func`` returns without raising ``expected_exception``.
        Any other exception raised by ``func`` propagates unchanged.
    """
    def _name_of(obj):
        # Callables such as functools.partial objects carry no __name__.
        return getattr(obj, '__name__', repr(obj))

    try:
        func(*args, **kwargs)
    except expected_exception:
        return
    # Did not raise the expected exception: build a message that also works
    # when a tuple of exception classes was supplied.
    if isinstance(expected_exception, tuple):
        expected_name = '(%s)' % ', '.join(_name_of(exc) for exc in expected_exception)
    else:
        expected_name = _name_of(expected_exception)
    # Raised explicitly (instead of `assert False`) so the check survives `python -O`.
    raise AssertionError("%s did not raise %s" % (_name_of(func), expected_name))
def default_logger():
    """Return the module logger used to report rng seed information.

    The first call attaches a stderr stream handler that formats records as
    '[LEVEL] message'; later calls find the handler already present and
    return the same logger untouched.
    """
    log = logging.getLogger(__name__)
    # logging.getLogger() hands back the same object every time, so an
    # existing handler means the one-time setup below has already run.
    if log.handlers:
        return log
    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
    log.addHandler(stream_handler)
    # Only choose a level when nothing in the logger hierarchy has set one.
    if log.getEffectiveLevel() == logging.NOTSET:
        log.setLevel(logging.INFO)
    return log
@contextmanager
def random_seed(seed=None):
    """
    Context manager that runs its block with np, mx and python's `random` freshly seeded.

    Parameters
    ----------
    seed : value handed to np.random.seed, mx.random.seed and random.seed.
        When None, np.random is first re-seeded with no argument and a
        seed is then drawn from it.

    Deterministic use:

        with random_seed(1234):
            ...

    Non-deterministic use:

        with random_seed():
            ...

    When the block ends (normally or by exception) all three generators are
    re-seeded with a value that was drawn from np.random before the block
    started, so their state afterwards depends on the pre-block state.
    """
    int32_max = np.iinfo(np.int32).max
    try:
        # Drawn before any re-seeding so it reflects the caller's rng state.
        restore_seed = np.random.randint(0, int32_max)
        if seed is None:
            np.random.seed()
            seed = np.random.randint(0, int32_max)
        default_logger().debug('Setting np, mx and python random seeds = %s', seed)
        for seed_fn in (np.random.seed, mx.random.seed, random.seed):
            seed_fn(seed)
        yield
    finally:
        # Put the generators back on a path derived from their prior state.
        for seed_fn in (np.random.seed, mx.random.seed, random.seed):
            seed_fn(restore_seed)
def _assert_raise_cuxx_version_not_satisfied(min_version, cfg):
    """Build a test decorator that expects failure when a library version is too old.

    Parameters
    ----------
    min_version : str
        Minimum dotted version (format ``num(.num)*``) the decorated test requires.
    cfg : dict
        Must contain:
        'TEST_OFF_ENV_VAR' - name of an env var; the value 'true' marks the library as off.
        'VERSION_ENV_VAR'  - name of the env var holding the installed version.
        'DEFAULT_VERSION'  - version assumed when VERSION_ENV_VAR is unset
                             (not used when the library is marked off).

    Returns
    -------
    A decorator. The wrapped test runs normally when the version requirement is
    met or the current mx context is 'cpu'; otherwise the test is run under
    pytest.raises and must raise MXNetError or RuntimeError.
    """
    def less_than(version_left, version_right):
        """Compares two version strings in the format num(.[num])*"""
        if not version_left or not version_right:
            return False

        left = version_left.split(".")
        right = version_right.split(".")

        # 0 pad shortest version - e.g.
        # less_than("9.1", "9.1.9") == less_than("9.1.0", "9.1.9")
        # Pad with the *string* "0" so padded and parsed components compare alike.
        longest = max(len(left), len(right))
        left.extend(["0"] * (longest - len(left)))
        right.extend(["0"] * (longest - len(right)))

        # Compare component by component. Textually different components that
        # are numerically equal ("07" vs "7") must not end the comparison.
        for l_part, r_part in zip(left, right):
            if l_part == r_part:
                continue
            l_num, r_num = int(l_part), int(r_part)
            if l_num != r_num:
                return l_num < r_num
        return False

    def test_helper(orig_test):
        @functools.wraps(orig_test)
        def test_new(*args, **kwargs):
            cuxx_off = os.getenv(cfg['TEST_OFF_ENV_VAR']) == 'true'
            # With the library marked off there is no meaningful default version.
            cuxx_env_version = os.getenv(cfg['VERSION_ENV_VAR'],
                                         None if cuxx_off else cfg['DEFAULT_VERSION'])
            cuxx_test_disabled = cuxx_off or less_than(cuxx_env_version, min_version)
            if not cuxx_test_disabled or mx.context.current_context().device_type == 'cpu':
                orig_test(*args, **kwargs)
            else:
                pytest.raises((MXNetError, RuntimeError), orig_test, *args, **kwargs)
        return test_new
    return test_helper
def assert_raises_cudnn_not_satisfied(min_version):
    """Return the version-gating test decorator configured for cuDNN.

    Delegates to _assert_raise_cuxx_version_not_satisfied, reading the
    CUDNN_OFF_TEST_ONLY and CUDNN_VERSION environment variables and falling
    back to version '7.3.1'.
    """
    cudnn_cfg = dict(
        TEST_OFF_ENV_VAR='CUDNN_OFF_TEST_ONLY',
        VERSION_ENV_VAR='CUDNN_VERSION',
        DEFAULT_VERSION='7.3.1',
    )
    return _assert_raise_cuxx_version_not_satisfied(min_version, cudnn_cfg)
def assert_raises_cuda_not_satisfied(min_version):
    """Return the version-gating test decorator configured for CUDA.

    Delegates to _assert_raise_cuxx_version_not_satisfied, reading the
    CUDA_OFF_TEST_ONLY and CUDA_VERSION environment variables and falling
    back to version '10.1'.
    """
    cuda_cfg = dict(
        TEST_OFF_ENV_VAR='CUDA_OFF_TEST_ONLY',
        VERSION_ENV_VAR='CUDA_VERSION',
        DEFAULT_VERSION='10.1',
    )
    return _assert_raise_cuxx_version_not_satisfied(min_version, cuda_cfg)
def with_seed(seed=None):
    """
    Test decorator that seeds the np, mx and python rngs before each run of a test.

    Parameters
    ----------
    seed : fixed seed for np.random, mx.random and python's random, or None.

    Seed selection, in priority order:
      1. the `seed` argument, when given;
      2. the MXNET_TEST_SEED environment variable, when set;
      3. otherwise a value drawn from np.random.

    The seed is logged at INFO level in cases 1 and 2 (as a reminder that the
    test runs with fixed data) and at DEBUG level in case 3. If the test
    raises, the message is repeated at WARNING level before re-raising.

        @with_seed()
        def test_ok_with_random_data():
            ...

        @with_seed(1234)
        def test_not_ok_with_random_data():
            ...

    Applying @with_seed() to every test gives test isolation and
    reproducibility of failures: set MXNET_TEST_SEED to the reported value
    and rerun the failing test with

        pytest --verbose --capture=no <test_module_name.py>::<failing_test>

    To run a test repeatedly, set MXNET_TEST_COUNT=<NNN> in the environment.
    To see the seeds of even the passing tests, add '--log-level=DEBUG' to pytest.
    """
    def test_helper(orig_test):
        @functools.wraps(orig_test)
        def test_new(*args, **kwargs):
            repeat_count = int(os.getenv('MXNET_TEST_COUNT', '1'))
            seed_from_env = os.getenv('MXNET_TEST_SEED')
            for run_idx in range(repeat_count):
                if seed is None and seed_from_env is None:
                    this_test_seed = np.random.randint(0, np.iinfo(np.int32).max)
                    log_level = logging.DEBUG
                else:
                    # The decorator argument takes precedence over the environment.
                    this_test_seed = seed if seed is not None else int(seed_from_env)
                    log_level = logging.INFO
                # Captured after any draw above, so np.random resumes from here afterwards.
                post_test_state = np.random.get_state()
                np.random.seed(this_test_seed)
                mx.random.seed(this_test_seed)
                random.seed(this_test_seed)
                logger = default_logger()
                # Logged before running the test so it is visible even with an ensuing core dump.
                if repeat_count > 1:
                    progress = '{} of {}: '.format(run_idx + 1, repeat_count)
                else:
                    progress = ''
                test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}'
                            ' to reproduce.').format(progress, this_test_seed)
                logger.log(log_level, test_msg)
                try:
                    orig_test(*args, **kwargs)
                except BaseException:
                    # On any failure, repeat the seed message at WARNING level to be sure it's seen.
                    if log_level < logging.WARNING:
                        logger.warning(test_msg)
                    raise
                finally:
                    # Test isolation: finish pending mx work, then restore np.random's state.
                    mx.nd.waitall()
                    np.random.set_state(post_test_state)
        return test_new
    return test_helper
def setup_module():
    """
    Module-level setup hook ('magic name' run by pytest before the tests of a module):
    seeds the np, mx and python rngs and logs the seed so a crash can be reproduced.

    The seed is MXNET_MODULE_SEED when that environment variable is set,
    otherwise a value drawn from np.random.

    Segfault-debug procedure for a module called test_module.py:

    1. Run "pytest --verbose test_module.py". A seg-faulting output might be:

        [INFO] Setting module np/mx/python random seeds, use MXNET_MODULE_SEED=4018804151 to reproduce.
        test_module.test1 ... ok
        test_module.test2 ... Illegal instruction (core dumped)

    2. Copy the module-starting seed into the next command, then run:

        MXNET_MODULE_SEED=4018804151 pytest --log-level=DEBUG --verbose test_module.py

       Output might be:

        [WARNING] *** module-level seed is set: all tests running deterministically ***
        [INFO] Setting module np/mx/python random seeds, use MXNET_MODULE_SEED=4018804151 to reproduce.
        test_module.test1 ... [DEBUG] Setting test np/mx/python random seeds, use MXNET_TEST_SEED=3935862516 to reproduce.
        ok
        test_module.test2 ... [DEBUG] Setting test np/mx/python random seeds, use MXNET_TEST_SEED=1435005594 to reproduce.
        Illegal instruction (core dumped)

    3. Copy the segfaulting-test seed into the command:

        MXNET_TEST_SEED=1435005594 pytest --log-level=DEBUG --verbose test_module.py::test2

       Output might be:

        [INFO] Setting module np/mx/python random seeds, use MXNET_MODULE_SEED=2481884723 to reproduce.
        test_module.test2 ... [INFO] Setting test np/mx/python random seeds, use MXNET_TEST_SEED=1435005594 to reproduce.
        Illegal instruction (core dumped)

    4. Finally reproduce the segfault directly under gdb (might need additional os packages)
       by editing the bottom of test_module.py to be

        if __name__ == '__main__':
            logging.getLogger().setLevel(logging.DEBUG)
            test2()

       and running:

        MXNET_TEST_SEED=1435005594 gdb -ex r --args python test_module.py

    5. When finished debugging the segfault, remember to unset any exported MXNET_ seed
       variables in the environment to return to non-deterministic testing (a good thing).
    """
    logger = default_logger()
    module_seed_str = os.getenv('MXNET_MODULE_SEED')
    if module_seed_str is not None:
        seed = int(module_seed_str)
        logger.warning('*** module-level seed is set: all tests running deterministically ***')
    else:
        seed = np.random.randint(0, np.iinfo(np.int32).max)
    logger.info('Setting module np/mx/python random seeds, use MXNET_MODULE_SEED=%s to reproduce.', seed)
    for seed_fn in (np.random.seed, mx.random.seed, random.seed):
        seed_fn(seed)
    # MXNET_TEST_SEED overrides MXNET_MODULE_SEED for tests decorated with
    # 'with_seed()'. Tell the user once, here at the module level.
    if os.getenv('MXNET_TEST_SEED') is not None:
        logger.warning('*** test-level seed set: all "@with_seed()" tests run deterministically ***')
def teardown_module():
    """
    Module-level teardown hook ('magic name' run by pytest after the tests of a module).

    Blocks on mx.nd.waitall() so that all operations issued by one test file
    have finished before the next one starts.
    """
    mx.nd.waitall()
def run_in_spawned_process(func, env, *args):
    """
    Helper function to run a test in its own process.

    Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend.
    Adds a seed as first arg to propagate determinism.

    Parameters
    ----------
    func : function to run in a spawned process.
    env : dict of additional environment values to set temporarily in the environment before exec.
    args : args to pass to the function.

    Returns
    -------
    Whether the python version supports running the function as a spawned process.

    Raises
    ------
    AssertionError
        If the spawned process finishes with a non-zero exit code.

    This routine calculates a random seed and passes it into the test as a first argument. If the
    test uses random values, it should include an outer 'with random_seed(seed):'. If the
    test needs to return values to the caller, consider use of shared variable arguments.
    """
    try:
        mpctx = mp.get_context('spawn')
    except (AttributeError, ValueError):
        # AttributeError: multiprocessing has no get_context at all;
        # ValueError: the 'spawn' start method is not available.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
              sys.version_info[0:2], file=sys.stderr, end='')
        return False

    seed = np.random.randint(0, 1024*1024*1024)
    orig_environ = os.environ.copy()
    try:
        # The child inherits the environment as it is at start(), so apply overrides first.
        for key, value in env.items():
            os.environ[key] = str(value)
        # Prepend seed as first arg
        p = mpctx.Process(target=func, args=(seed,) + args)
        p.start()
        p.join()
        if p.exitcode != 0:
            # Raised explicitly so the check is not stripped under `python -O`.
            raise AssertionError("Non-zero exit code %d from %s()." % (p.exitcode, func.__name__))
    finally:
        # Restore the parent's environment exactly, dropping any keys added above.
        os.environ.clear()
        os.environ.update(orig_environ)
    return True
def retry(n):
    """Decorator factory for stochastic tests: run the test up to n times before failing.

    Only AssertionError triggers another attempt; the AssertionError from the
    final attempt is re-raised and any other exception propagates immediately.
    """
    # TODO(szha): replace with flaky
    # https://github.com/apache/incubator-mxnet/issues/17803
    assert n > 0

    def test_helper(orig_test):
        @functools.wraps(orig_test)
        def test_new(*args, **kwargs):
            """Wrapper for tests function."""
            final_attempt = n - 1
            for attempt in range(n):
                try:
                    orig_test(*args, **kwargs)
                except AssertionError as err:
                    if attempt == final_attempt:
                        raise err
                    # Let outstanding mx work finish before trying again.
                    mx.nd.waitall()
                else:
                    return
        return test_new
    return test_helper