Skip to content

Commit

Permalink
[BUGFIX] fix numpy op fallback bug when ndarray in kwargs (apache#20233)
Browse files Browse the repository at this point in the history
  • Loading branch information
wkcn authored and josephevans committed Feb 8, 2022
1 parent 7f50577 commit 5b6e717
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,24 +153,32 @@ def _as_mx_np_array(object, ctx=None):
raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))


def _as_onp_array(object):
"""Convert object to mxnet.numpy.ndarray."""
cur_ctx = None
def _as_onp_array(object, cur_ctx=None):
"""Convert object to numpy.ndarray."""
def _update_ctx(cur_ctx, tmp_ctx):
if cur_ctx is None:
cur_ctx = tmp_ctx
elif tmp_ctx is not None and cur_ctx != tmp_ctx:
raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args
' input ndarrays are allocated on different devices: {} and {}'
.format(str(cur_ctx, tmp_ctx)))
return cur_ctx

if isinstance(object, ndarray):
return object.asnumpy(), object.ctx
elif isinstance(object, (list, tuple)):
tmp = []
for arr in object:
arr, tmp_ctx = _as_onp_array(arr)
# if isinstance(arr, (list, tuple)):
# raise TypeError('type {} not supported'.format(str(type(arr))))
arr, tmp_ctx = _as_onp_array(arr, cur_ctx)
tmp.append(arr)
if cur_ctx is None:
cur_ctx = tmp_ctx
elif tmp_ctx is not None and cur_ctx != tmp_ctx:
raise ValueError('Ambiguous to set the context for the output ndarray since' # pylint: disable=too-few-format-args
' input ndarrays are allocated on different devices: {} and {}'
.format(str(cur_ctx, tmp_ctx)))
cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
return object.__class__(tmp), cur_ctx
elif isinstance(object, dict):
tmp = dict()
for key, value in object.items():
value, tmp_ctx = _as_onp_array(value, cur_ctx)
tmp[key] = value
cur_ctx = _update_ctx(cur_ctx, tmp_ctx)
return object.__class__(tmp), cur_ctx
else:
return object, cur_ctx
Expand Down Expand Up @@ -284,13 +292,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
raise ValueError("Falling back to NumPy operator {} with autograd active is not supported."
"Please consider moving the operator to the outside of the autograd scope.")\
.format(func)
new_args, cur_ctx = _as_onp_array(args)
cur_ctx = None
new_args, cur_ctx = _as_onp_array(args, cur_ctx)
new_kwargs, cur_ctx = _as_onp_array(kwargs, cur_ctx)
if cur_ctx is None:
raise ValueError('Unknown context for the input ndarrays. It is probably a bug. Please'
' create an issue on GitHub.')
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v
if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
import logging
logging.warning("np.%s is a fallback operator, "
"which is actually using official numpy's implementation.", func_name)
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD[func] = True
out = func(*new_args, **new_kwargs)
return _as_mx_np_array(out, ctx=cur_ctx)
else:
Expand Down

0 comments on commit 5b6e717

Please sign in to comment.