diff --git a/ndcube/mixins/plotting.py b/ndcube/mixins/plotting.py index c6b1424fd..8d488667a 100644 --- a/ndcube/mixins/plotting.py +++ b/ndcube/mixins/plotting.py @@ -61,6 +61,12 @@ def plot(self, axes=None, plot_axis_indices=None, axes_coordinates=None, same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. + A str entry in axes_coordinates signifies that an extra_coord will be used for the axis's coordinates. + The str must be a valid name of an extra_coord that corresponds to the same axis to which it is applied in the plot. + + """ # If old API is used, convert to new API. @@ -237,7 +243,7 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None else: # Else manually set axes x and y values based on user's input for axes_coordinates. new_axes_coordinates, new_axis_units, default_labels = \ - self._derive_axes_coordinates(axes_coordinates, axes_units) + self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape) # Initialize axes object and set values along axis. fig, ax = plt.subplots(1, 1) # Since we can't assume the x-axis will be uniform, create NonUniformImage @@ -291,6 +297,10 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None, same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. + A str entry in axes_coordinates signifies that an extra_coord will be used for the axis's coordinates. + The str must be a valid name of an extra_coord that corresponds to the same axis to which it is applied in the plot. """ # For convenience in inserting dummy variables later, ensure @@ -311,16 +321,20 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None, data = np.ma.masked_array(data, self.mask) # If axes_coordinates not provided generate an ImageAnimatorWCS plot # using NDCube's wcs object. + new_axes_coordinates, new_axes_units, default_labels = \ + self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape, edges=True) + if (axes_coordinates[plot_axis_indices[0]] is None and axes_coordinates[plot_axis_indices[1]] is None): + # If there are missing axes in WCS object, add corresponding dummy axes to data. if data.ndim < self.wcs.naxis: new_shape = list(data.shape) for i in np.arange(self.wcs.naxis)[self.missing_axes[::-1]]: new_shape.insert(i, 1) # Also insert dummy coordinates and units. - axes_coordinates.insert(i, None) - axes_units.insert(i, None) + new_axes_coordinates.insert(i, None) + new_axes_units.insert(i, None) # Iterate plot_axis_indices if neccessary for j, pai in enumerate(plot_axis_indices): if pai >= i: @@ -329,16 +343,23 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None, data = data.reshape(new_shape) # Generate plot ax = ImageAnimatorWCS(data, wcs=self.wcs, image_axes=plot_axis_indices, - unit_x_axis=axes_units[plot_axis_indices[0]], - unit_y_axis=axes_units[plot_axis_indices[1]], - axis_ranges=axes_coordinates, **kwargs) + unit_x_axis=new_axes_units[plot_axis_indices[0]], + unit_y_axis=new_axes_units[plot_axis_indices[1]], + axis_ranges=new_axes_coordinates, **kwargs) + + # Set the labels of the plot + ax.axes.coords[0].set_axislabel(self.wcs.world_axis_physical_types[plot_axis_indices[0]]) + ax.axes.coords[1].set_axislabel(self.wcs.world_axis_physical_types[plot_axis_indices[1]]) + # If one of the plot axes is set manually, produce a basic ImageAnimator object. else: - new_axes_coordinates, new_axes_units, default_labels = \ - self._derive_axes_coordinates(axes_coordinates, axes_units) # If axis labels not set by user add to kwargs. ax = ImageAnimator(data, image_axes=plot_axis_indices, axis_ranges=new_axes_coordinates, **kwargs) + + # Add the labels of the plot + ax.axes.set_xlabel(default_labels[plot_axis_indices[0]]) + ax.axes.set_ylabel(default_labels[plot_axis_indices[1]]) return ax def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None, @@ -371,9 +392,7 @@ def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None, else: unit_x_axis = None # Put xdata back into axes_coordinates as a masked array. - if len(xdata.shape) > 1: - # Since LineAnimator currently only accepts 1-D arrays for the x-axis, collapse xdata # to single dimension by taking mean along non-plotting axes. index = utils.wcs.get_dependent_data_axes(self.wcs, plot_axis_index, self.missing_axes) @@ -382,7 +401,6 @@ def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None, index = np.delete(index, reduce_axis) # Reduce the data by taking mean xdata = np.mean(xdata, axis=tuple(index)) - axes_coordinates[plot_axis_index] = xdata # Set default x label default_xlabel = "{0} [{1}]".format(xname, unit_x_axis) @@ -406,7 +424,7 @@ def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None, ylabel="Data [{0}]".format(data_unit), **kwargs) return ax - def _derive_axes_coordinates(self, axes_coordinates, axes_units): + def _derive_axes_coordinates(self, axes_coordinates, axes_units, data_shape, edges=False): new_axes_coordinates = [] new_axes_units = [] default_labels = [] @@ -414,13 +432,27 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units): for i, axis_coordinate in enumerate(axes_coordinates): # If axis coordinate is None, derive axis values from WCS. if axis_coordinate is None: - # N.B. This assumes axes are independent. Fix this before merging!!! - new_axis_coordinate = self.axis_world_coords(i) + + # If the new_axis_coordinate is not independent, i.e. dimension is >2D + # and not equal to dimension of data, then the new_axis_coordinate must + # be reduced to a 1D ndarray by taking the mean along all non-plotting axes. + new_axis_coordinate = self.axis_world_coords(i, edges=edges) axis_label_text = self.world_axis_physical_types[i] + # If the shape of the data is not 1, or all the axes are not dependent + if new_axis_coordinate.ndim != 1 and new_axis_coordinate.ndim != len(data_shape): + index = utils.wcs.get_dependent_data_axes(self.wcs, i, self.missing_axes) + reduce_axis = np.where(index == np.array([i]))[0] + + index = np.delete(index, reduce_axis) + # Reduce the data by taking mean + new_axis_coordinate = np.mean(new_axis_coordinate, axis=tuple(index)) + elif isinstance(axis_coordinate, str): # If axis coordinate is a string, derive axis values from # corresponding extra coord. - new_axis_coordinate = self.extra_coords[axis_coordinate]["value"] + # Calculate edge value if required + new_axis_coordinate = _get_extra_coord_edges(self.extra_coords[axis_coordinate]["value"]) if edges else \ + self.extra_coords[axis_coordinate]["value"] axis_label_text = axis_coordinate else: # Else user must have manually set the axis coordinates. @@ -445,13 +477,14 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units): new_axis_unit = None else: raise TypeError(INVALID_UNIT_SET_MESSAGE) + # Derive default axis label - if type(new_axis_coordinate[0]) is datetime.datetime: + if type(new_axis_coordinate) is datetime.datetime: if axis_label_text == default_label_text: - default_label = "{0}".format(new_axis_coordinate[0].strftime("%Y/%m/%d %H:%M")) + default_label = "{0}".format(new_axis_coordinate.strftime("%Y/%m/%d %H:%M")) else: default_label = "{0} [{1}]".format( - axis_label_text, new_axis_coordinate[0].strftime("%Y/%m/%d %H:%M")) + axis_label_text, new_axis_coordinate.strftime("%Y/%m/%d %H:%M")) else: default_label = "{0} [{1}]".format(axis_label_text, new_axis_unit) # Append new coordinates, units and labels to output list. diff --git a/ndcube/mixins/sequence_plotting.py b/ndcube/mixins/sequence_plotting.py index 048a66929..4146e9d73 100644 --- a/ndcube/mixins/sequence_plotting.py +++ b/ndcube/mixins/sequence_plotting.py @@ -62,6 +62,10 @@ def plot(self, axes=None, plot_axis_indices=None, `None` (implies derive the coordinates from the WCS objects), an `astropy.units.Quantity` or a `numpy.ndarray` of coordinates for each pixel, or a `str` denoting a valid extra coordinate. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. + A str entry in axes_coordinates signifies that an extra_coord will be used for the axis's coordinates. + The str must be a valid name of an extra_coord that corresponds to the same axis to which it is applied in the plot. axes_units: `None or `list` of `None`, `astropy.units.Unit` and/or `str` If None units derived from the WCS objects will be used for all axes. @@ -168,6 +172,10 @@ def plot_as_cube(self, axes=None, plot_axis_indices=None, `None` (implies derive the coordinates from the WCS objects), an `astropy.units.Quantity` or a `numpy.ndarray` of coordinates for each pixel, or a `str` denoting a valid extra coordinate. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. + A str entry in axes_coordinates signifies that an extra_coord will be used for the axis's coordinates. + The str must be a valid name of an extra_coord that corresponds to the same axis to which it is applied in the plot. axes_units: `None or `list` of `None`, `astropy.units.Unit` and/or `str` If None units derived from the WCS objects will be used for all axes. @@ -254,6 +262,10 @@ def _plot_1D_sequence(self, axes_coordinates=None, each pixel along the x-axis. If a `str`, denotes the extra coordinate to be used. The extra coordinate must correspond to the sequence axis. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. + A str entry in axes_coordinates signifies that an extra_coord will be used for the axis's coordinates. + The str must be a valid name of an extra_coord that corresponds to the same axis to which it is applied in the plot. axes_units: `astropy.unit.Unit` or valid unit `str` or length 1 `list` of those types. Unit in which X-axis should be displayed. Must be compatible with the unit of @@ -683,6 +695,8 @@ class ImageAnimatorNDCubeSequence(ImageAnimatorWCS): same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axis_ranges should be an array of + pixel_edges. interval: `int` Animation interval in ms @@ -778,6 +792,8 @@ class ImageAnimatorCubeLikeNDCubeSequence(ImageAnimatorWCS): same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axis_ranges should be an array of + pixel_edges. interval: `int` Animation interval in ms @@ -895,6 +911,8 @@ class LineAnimatorNDCubeSequence(LineAnimator): same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axis_ranges should be an array of + pixel_edges. interval: `int` Animation interval in ms @@ -1129,6 +1147,8 @@ class LineAnimatorCubeLikeNDCubeSequence(LineAnimator): same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + The physical coordinates expected by axis_ranges should be an array of + pixel_edges. interval: `int` Animation interval in ms @@ -1349,6 +1369,8 @@ def _prep_axes_kwargs(naxis, plot_axis_indices, axes_coordinates, axes_units): axes_coordinates: `None` or `list` of `None` `astropy.units.Quantity` `numpy.ndarray` `str` Length of list equals number of sequence axes. + The physical coordinates expected by axes_coordinates should be an array of + pixel_edges. axes_units: None or `list` of `None` `astropy.units.Unit` or `str` Length of list equals number of sequence axes. diff --git a/ndcube/tests/test_plotting.py b/ndcube/tests/test_plotting.py index 327267fc2..f3c9a7f28 100644 --- a/ndcube/tests/test_plotting.py +++ b/ndcube/tests/test_plotting.py @@ -33,6 +33,7 @@ data = np.array([[[1, 2, 3, 4], [2, 4, 5, 3], [0, -1, 2, 3]], [[2, 4, 5, 1], [10, 5, 2, 2], [10, 3, 3, 0]]]) + uncertainty = np.sqrt(data) mask_cube = data < 0