diff --git a/changelog/241.bugfix.rst b/changelog/241.bugfix.rst new file mode 100644 index 000000000..6f81d27d2 --- /dev/null +++ b/changelog/241.bugfix.rst @@ -0,0 +1,4 @@ +Changes behavior of NDCubeSequence slicing. Previously, a slice item of interval +length 1 would cause an NDCube object to be returned. Now an NDCubeSequence made +up of 1 NDCube is returned. This is consistent with how interval length 1 slice +items slice arrays. diff --git a/ndcube/ndcube_sequence.py b/ndcube/ndcube_sequence.py index 3835f3aa0..2622075e9 100644 --- a/ndcube/ndcube_sequence.py +++ b/ndcube/ndcube_sequence.py @@ -1,3 +1,6 @@ +import copy +import numbers + import numpy as np import astropy.units as u @@ -80,8 +83,12 @@ def cube_like_world_axis_physical_types(self): return self.data[0].world_axis_physical_types def __getitem__(self, item): - if len(self.dimensions) == 1: + if isinstance(item, numbers.Integral): return self.data[item] + elif isinstance(item, slice): + result = copy.deepcopy(self) + result.data = self.data[item] + return result else: return utils.sequence.slice_sequence(self, item) diff --git a/ndcube/tests/test_ndcubesequence.py b/ndcube/tests/test_ndcubesequence.py index 63824b1ba..39f868b53 100644 --- a/ndcube/tests/test_ndcubesequence.py +++ b/ndcube/tests/test_ndcubesequence.py @@ -116,12 +116,14 @@ (seq[2], NDCube), (seq[3], NDCube), (seq[0:1], NDCubeSequence), + (seq[0:1, 0:2], NDCubeSequence), + (seq[0:1, 0], NDCubeSequence), (seq[1:3], NDCubeSequence), (seq[0:2], NDCubeSequence), (seq[slice(0, 2)], NDCubeSequence), (seq[slice(0, 3)], NDCubeSequence), ]) -def test_slice_first_index_sequence(test_input, expected): +def test_slice_first_index_sequence_type(test_input, expected): assert isinstance(test_input, expected) @@ -132,7 +134,7 @@ def test_slice_first_index_sequence(test_input, expected): (seq[slice(0, 2)], 2 * u.pix), (seq[slice(0, 3)], 3 * u.pix), ]) -def test_slice_first_index_sequence(test_input, expected): +def test_slice_first_index_sequence_dimensions(test_input, expected): assert test_input.dimensions[0] == expected diff --git a/ndcube/utils/sequence.py b/ndcube/utils/sequence.py index b85ccebb6..c16b1510c 100644 --- a/ndcube/utils/sequence.py +++ b/ndcube/utils/sequence.py @@ -142,8 +142,13 @@ def _get_sequence_items_from_slice_item(slice_item, n_cubes, cube_item=slice(Non stop = -1 else: stop = no_none_slice.stop - sequence_items = [SequenceItem(i, cube_item) - for i in range(no_none_slice.start, stop, no_none_slice.step)] + # If slice has interval length 1, make sequence index length 1 slice to + # ensure dimension is not dropped in accordance with slicing convention. + if abs(stop - no_none_slice.start) == 1: + sequence_items = [SequenceItem(slice_item, cube_item)] + else: + sequence_items = [SequenceItem(i, cube_item) + for i in range(no_none_slice.start, stop, no_none_slice.step)] return sequence_items @@ -201,12 +206,19 @@ def slice_sequence_by_sequence_items(cubesequence, sequence_items): """ result = deepcopy(cubesequence) if len(sequence_items) == 1: - return result.data[sequence_items[0].sequence_index][sequence_items[0].cube_item] + # If sequence item is interval length 1 slice, ensure an NDCubeSequence + # is returned in accordance with slicing convention. + # Due to code up to this point, if sequence item is a slice, it can only + # be an interval length 1 slice. + if isinstance(sequence_items[0].sequence_index, slice): + result.data = [result.data[sequence_items[0].sequence_index.start][sequence_items[0].cube_item]] + else: + result = result.data[sequence_items[0].sequence_index][sequence_items[0].cube_item] else: data = [result.data[sequence_item.sequence_index][sequence_item.cube_item] for sequence_item in sequence_items] result.data = data - return result + return result def _index_sequence_as_cube(cubesequence, item):