Skip to content

Commit

Permalink
array API fixes for astype
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed May 4, 2023
1 parent 316c63d commit 9cd9078
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
73 changes: 43 additions & 30 deletions xarray/core/accessor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

import numpy as np

from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.types import T_DataArray

Expand Down Expand Up @@ -2085,13 +2086,16 @@ def _get_res_multi(val, pat):
else:
# dtype MUST be object or strings can be truncated
# See: /~https://github.com/numpy/numpy/issues/8352
return self._apply(
func=_get_res_multi,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxgroups},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=_get_res_multi,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxgroups},
),
self._obj.dtype.kind,
)

def extractall(
self,
Expand Down Expand Up @@ -2258,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype):

return res

return self._apply(
# dtype MUST be object or strings can be truncated
# See: /~https://github.com/numpy/numpy/issues/8352
func=_get_res,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[group_dim, match_dim]],
output_sizes={group_dim: maxgroups, match_dim: maxcount},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
# dtype MUST be object or strings can be truncated
# See: /~https://github.com/numpy/numpy/issues/8352
func=_get_res,
func_args=(pat,),
dtype=np.object_,
output_core_dims=[[group_dim, match_dim]],
output_sizes={group_dim: maxgroups, match_dim: maxcount},
),
self._obj.dtype.kind,
)

def findall(
self,
Expand Down Expand Up @@ -2385,13 +2392,16 @@ def _partitioner(

# dtype MUST be object or strings can be truncated
# See: /~https://github.com/numpy/numpy/issues/8352
return self._apply(
func=arrfunc,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: 3},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=arrfunc,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: 3},
),
self._obj.dtype.kind,
)

def partition(
self,
Expand Down Expand Up @@ -2510,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype):

# dtype MUST be object or strings can be truncated
# See: /~https://github.com/numpy/numpy/issues/8352
return self._apply(
func=_dosplit,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxsplit},
).astype(self._obj.dtype.kind)
return duck_array_ops.astype(
self._apply(
func=_dosplit,
func_args=(sep,),
dtype=np.object_,
output_core_dims=[[dim]],
output_sizes={dim: maxsplit},
),
self._obj.dtype.kind,
)

def split(
self,
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
pads = [(0, 0) if d != dim else dim_pad for d in self.dims]

data = np.pad(
trimmed_data.astype(dtype),
duck_array_ops.astype(trimmed_data, dtype),
pads,
mode="constant",
constant_values=fill_value,
Expand Down Expand Up @@ -1569,7 +1569,7 @@ def pad(
pad_option_kwargs["reflect_type"] = reflect_type

array = np.pad(
self.data.astype(dtype, copy=False),
duck_array_ops.astype(self.data, dtype, copy=False),
pad_width_by_index,
mode=mode,
**pad_option_kwargs,
Expand Down Expand Up @@ -2437,7 +2437,7 @@ def rolling_window(
"""
if fill_value is dtypes.NA: # np.nan is passed
dtype, fill_value = dtypes.maybe_promote(self.dtype)
var = self.astype(dtype, copy=False)
var = duck_array_ops.astype(self, dtype, copy=False)
else:
dtype = self.dtype
var = self
Expand Down

0 comments on commit 9cd9078

Please sign in to comment.