diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 0f7689bd..6a639659 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -624,10 +624,13 @@ def reduction( chunk_mem = chunk_memory(intermediate_dtype, result.chunksize) for i, s in enumerate(result.shape): if i in axis: + assert result.chunksize[i] == 1 # result of reduction if len(axis) > 1: - # TODO: make sure chunk size doesn't exceed max_mem for multi-axis reduction - target_chunks[i] = s + # multi-axis: don't exceed original chunksize in any reduction axis + # TODO: improve to use up to max_mem + target_chunks[i] = min(s, x.chunksize[i]) else: + # single axis: see how many result chunks fit in max_mem # factor of 4 is memory for {compressed, uncompressed} x {input, output} (see rechunk.py) target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4) if target_chunk_size <= 1: