diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 3ed68dbe8b8a..c97250028cb8 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -20,6 +20,7 @@ """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import +import io import re import atexit import ctypes @@ -76,6 +77,28 @@ def data_dir(): """ return os.getenv('MXNET_HOME', data_dir_default()) +class _Py2CompatibleUnicodeFileWriter(object): + """ + Wraps a file handle decorating the write command to unicode the content before writing. + This makes writing files opened with encoding='utf-8' compatible with Python 2 + """ + + def __init__(self, file_handle): + self._file_handle = file_handle + if sys.version_info[0] > 2: + self.unicode = str + else: + from functools import partial + # pylint: disable=undefined-variable + self.unicode = partial(unicode, encoding="utf-8") + # pylint: enable=undefined-variable + + def write(self, value): + self._file_handle.write(self.unicode(value)) + + def __getattr__(self, name): + return getattr(self._file_handle, name) + class _NullType(object): """Placeholder for arguments""" @@ -672,7 +695,7 @@ def get_module_file(module_name): module_path = module_name.split('.') module_path[-1] = 'gen_' + module_path[-1] file_name = os.path.join(path, '..', *module_path) + '.py' - module_file = open(file_name, 'w', encoding="utf-8") + module_file = _Py2CompatibleUnicodeFileWriter(io.open(file_name, 'w', encoding="utf-8")) dependencies = {'symbol': ['from ._internal import SymbolBase', 'from ..base import _Null'], 'ndarray': ['from ._internal import NDArrayBase',