diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 2e10d2a1..1b929c6a 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -165,6 +165,18 @@ def elemwise(func, *args, dtype=None): def index(x, key): + if not isinstance(key, tuple): + key = (key,) + + # Remove None values, to be filled in with expand_dims at end + where_none = [i for i, ind in enumerate(key) if ind is None] + for i, a in enumerate(where_none): + n = sum(isinstance(ind, Integral) for ind in key[:a]) + if n: + where_none[i] -= n + key = tuple(ind for ind in key if ind is not None) + + # Replace ellipsis with slices selection = replace_ellipsis(key, x.shape) # Use a Zarr BasicIndexer just to find the resulting array shape @@ -185,14 +197,12 @@ def index(x, key): else: raise NotImplementedError(f"Index not supported: {key}") - from cubed.core.ops import map_direct - # memory allocated by reading one chunk from input array # note that although the output chunk will overlap multiple input chunks, zarr will # read the chunks in series, reusing the buffer extra_required_mem = x.chunkmem - return map_direct( + out = map_direct( func, x, shape=shape, @@ -203,6 +213,13 @@ def index(x, key): offsets=offsets, ) + for axis in where_none: + from cubed.array_api.manipulation_functions import expand_dims + + out = expand_dims(out, axis=axis) + + return out + def _read_index_chunk(x, *arrays, source_chunks=None, offsets=None, block_id=None): array = arrays[0] diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 354b4c1b..3b103b6f 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -162,13 +162,13 @@ def test_negative(spec, executor): # Indexing -@pytest.mark.parametrize("i", [6]) +@pytest.mark.parametrize("i", [6, (6, None)]) def test_index_1d(spec, i): a = xp.arange(12, chunks=(4,), spec=spec) assert_array_equal(a[i].compute(), np.arange(12)[i]) -@pytest.mark.parametrize("i", [(2, 3)]) +@pytest.mark.parametrize("i", [(2, 3), (None, 2, 3)]) def test_index_2d(spec, i): a = xp.asarray( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], @@ -186,22 +186,26 @@ def test_slice_1d(spec, sl): @pytest.mark.parametrize( - "sl0, sl1", + "sl", [ (slice(None), slice(2, 4)), (slice(3), slice(2, None)), (slice(1, None), slice(4)), (slice(1, 3), slice(None)), + (slice(None), slice(2, 4)), + (None, slice(None), slice(2, 4)), # add a new dimension + (slice(None), None, slice(2, 4)), # add a new dimension + (slice(None), slice(2, 4), None), # add a new dimension ], ) -def test_slice_2d(spec, sl0, sl1): +def test_slice_2d(spec, sl): a = xp.asarray( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], chunks=(2, 2), spec=spec, ) x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) - assert_array_equal(a[sl0, sl1].compute(), x[sl0, sl1]) + assert_array_equal(a[sl].compute(), x[sl]) def test_slice_unsupported_step(spec):