Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes the bug associated with axes_coordinates #189

Merged
merged 13 commits into from
Aug 8, 2019
37 changes: 21 additions & 16 deletions ndcube/mixins/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,7 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None,
plot_axis_indices = [i if i >= 0 else self.data.ndim + i for i in plot_axis_indices]
# If axes kwargs not set by user, set them as list of Nones for
# each axis for consistent behaviour.
# Or if the axes_coordinates values are None for plot_axis_indices indexes
# This ensures no extraneous values from the user.
if axes_coordinates is None or (axes_coordinates[plot_axis_indices[0]] is None and
axes_coordinates[plot_axis_indices[1]] is None):
if axes_coordinates is None:
axes_coordinates = [None] * self.data.ndim
if axes_units is None:
axes_units = [None] * self.data.ndim
Expand All @@ -314,16 +311,20 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None,
data = np.ma.masked_array(data, self.mask)
# If axes_coordinates not provided generate an ImageAnimatorWCS plot
# using NDCube's wcs object.
new_axes_coordinates, new_axes_units, default_labels = \
self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape, edges=True)

if (axes_coordinates[plot_axis_indices[0]] is None and
axes_coordinates[plot_axis_indices[1]] is None):

# If there are missing axes in WCS object, add corresponding dummy axes to data.
if data.ndim < self.wcs.naxis:
new_shape = list(data.shape)
for i in np.arange(self.wcs.naxis)[self.missing_axes[::-1]]:
new_shape.insert(i, 1)
# Also insert dummy coordinates and units.
axes_coordinates.insert(i, None)
axes_units.insert(i, None)
new_axes_coordinates.insert(i, None)
new_axes_units.insert(i, None)
# Iterate plot_axis_indices if neccessary
for j, pai in enumerate(plot_axis_indices):
if pai >= i:
Expand All @@ -332,16 +333,20 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None,
data = data.reshape(new_shape)
# Generate plot
ax = ImageAnimatorWCS(data, wcs=self.wcs, image_axes=plot_axis_indices,
unit_x_axis=axes_units[plot_axis_indices[0]],
unit_y_axis=axes_units[plot_axis_indices[1]],
axis_ranges=axes_coordinates, **kwargs)
unit_x_axis=new_axes_units[plot_axis_indices[0]],
unit_y_axis=new_axes_units[plot_axis_indices[1]],
axis_ranges=new_axes_coordinates, **kwargs)

# ax.axes.coords[0].set_axislabel(self.wcs.world_axis_physical_types[plot_axis_indices[0]])
yashrsharma44 marked this conversation as resolved.
Show resolved Hide resolved
# ax.axes.coords[1].set_axislabel(self.wcs.world_axis_physical_types[plot_axis_indices[1]])
# If one of the plot axes is set manually, produce a basic ImageAnimator object.
else:
new_axes_coordinates, new_axes_units, default_labels = \
self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape, edges=True)
# If axis labels not set by user add to kwargs.
ax = ImageAnimator(data, image_axes=plot_axis_indices,
axis_ranges=new_axes_coordinates, **kwargs)
# Add the labels of the plot
# ax.axes.set_xlabel(default_labels[plot_axis_indices[0]])
yashrsharma44 marked this conversation as resolved.
Show resolved Hide resolved
# ax.axes.set_ylabel(default_labels[plot_axis_indices[1]])
return ax

def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None,
Expand Down Expand Up @@ -374,9 +379,7 @@ def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None,
else:
unit_x_axis = None
# Put xdata back into axes_coordinates as a masked array.

if len(xdata.shape) > 1:

# Since LineAnimator currently only accepts 1-D arrays for the x-axis, collapse xdata
# to single dimension by taking mean along non-plotting axes.
index = utils.wcs.get_dependent_data_axes(self.wcs, plot_axis_index, self.missing_axes)
Expand All @@ -385,7 +388,6 @@ def _animate_cube_1D(self, plot_axis_index=-1, axes_coordinates=None,
index = np.delete(index, reduce_axis)
# Reduce the data by taking mean
xdata = np.mean(xdata, axis=tuple(index))

axes_coordinates[plot_axis_index] = xdata
# Set default x label
default_xlabel = "{0} [{1}]".format(xname, unit_x_axis)
Expand Down Expand Up @@ -419,7 +421,9 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units, data_shape, edg
if axis_coordinate is None:
# Fix: We would downscale the dependent data into the shape of the axes.
yashrsharma44 marked this conversation as resolved.
Show resolved Hide resolved
xdata = self.axis_world_coords(i, edges=edges)
yashrsharma44 marked this conversation as resolved.
Show resolved Hide resolved
if xdata.ndim != data_shape[i]:

# If the shape of the data is not 1, or all the axes are not dependent
if xdata.ndim != 1 and xdata.ndim != len(data_shape):
axis_label_text = self.world_axis_physical_types[i]

index = utils.wcs.get_dependent_data_axes(self.wcs, i, self.missing_axes)
Expand All @@ -428,7 +432,8 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units, data_shape, edg
index = np.delete(index, reduce_axis)
# Reduce the data by taking mean
new_axis_coordinate = np.mean(xdata, axis=tuple(index))
new_axis_coordinate = xdata
else:
new_axis_coordinate = xdata
elif isinstance(axis_coordinate, str):
# If axis coordinate is a string, derive axis values from
# corresponding extra coord.
Expand Down