diff --git a/datamapplot/create_plots.py b/datamapplot/create_plots.py index 9a05f08..af087ab 100644 --- a/datamapplot/create_plots.py +++ b/datamapplot/create_plots.py @@ -22,7 +22,7 @@ def create_plot( data_map_coords, - labels, + labels=None, *, title=None, sub_title=None, @@ -155,28 +155,34 @@ def create_plot( The axes contained within the figure that the plot is rendered to. """ - cluster_label_vector = np.asarray(labels) - unique_non_noise_labels = [ - label for label in np.unique(cluster_label_vector) if label != noise_label - ] - if use_medoids: - label_locations = np.asarray( - [ - medoid(data_map_coords[cluster_label_vector == i]) - for i in unique_non_noise_labels - ] - ) + if labels is None: + label_locations = np.zeros((0, 2), dtype=np.float32) + label_text = [] + cluster_label_vector = np.full(data_map_coords.shape[0], "Unlabelled", dtype=object) + unique_non_noise_labels = [] else: - label_locations = np.asarray( - [ - data_map_coords[cluster_label_vector == i].mean(axis=0) - for i in unique_non_noise_labels - ] - ) - label_text = [ - textwrap.fill(x, width=label_wrap_width, break_long_words=False) - for x in unique_non_noise_labels - ] + cluster_label_vector = np.asarray(labels) + unique_non_noise_labels = [ + label for label in np.unique(cluster_label_vector) if label != noise_label + ] + if use_medoids: + label_locations = np.asarray( + [ + medoid(data_map_coords[cluster_label_vector == i]) + for i in unique_non_noise_labels + ] + ) + else: + label_locations = np.asarray( + [ + data_map_coords[cluster_label_vector == i].mean(axis=0) + for i in unique_non_noise_labels + ] + ) + label_text = [ + textwrap.fill(x, width=label_wrap_width, break_long_words=False) + for x in unique_non_noise_labels + ] if highlight_labels is not None: highlight_labels = [ textwrap.fill(x, width=label_wrap_width, break_long_words=False) @@ -225,7 +231,7 @@ def create_plot( label_colors = [label_color_map[x] for x in unique_non_noise_labels] - if color_label_text: + if color_label_text and len(label_colors) > 0: # Darken and reduce chroma of label colors to get text labels if darkmode: label_text_colors = pastel_palette(label_colors) @@ -421,21 +427,28 @@ def create_interactive_plot( """ if len(label_layers) == 0: - return None - - label_dataframe = pd.concat( - [ - label_text_and_polygon_dataframes( - labels, - data_map_coords, - noise_label=noise_label, - use_medoids=use_medoids, - cluster_polygons=cluster_boundary_polygons, - alpha=polygon_alpha, - ) - for labels in label_layers - ] - ) + label_dataframe = pd.DataFrame( + { + "x": [data_map_coords.T[0].mean()], + "y": [data_map_coords.T[1].mean()], + "label": [""], + "size": [np.power(data_map_coords.shape[0], 0.25)], + } + ) + else: + label_dataframe = pd.concat( + [ + label_text_and_polygon_dataframes( + labels, + data_map_coords, + noise_label=noise_label, + use_medoids=use_medoids, + cluster_polygons=cluster_boundary_polygons, + alpha=polygon_alpha, + ) + for labels in label_layers + ] + ) if label_color_map is None: if cmap is None: diff --git a/datamapplot/interactive_rendering.py b/datamapplot/interactive_rendering.py index fabccb4..ad94a5a 100644 --- a/datamapplot/interactive_rendering.py +++ b/datamapplot/interactive_rendering.py @@ -751,9 +751,12 @@ def render_html( # Compute text scaling size_range = label_dataframe["size"].max() - label_dataframe["size"].min() - label_dataframe["size"] = ( - label_dataframe["size"] - label_dataframe["size"].min() - ) * ((max_fontsize - min_fontsize) / size_range) + min_fontsize + if size_range > 0: + label_dataframe["size"] = ( + label_dataframe["size"] - label_dataframe["size"].min() + ) * ((max_fontsize - min_fontsize) / size_range) + min_fontsize + else: + label_dataframe["size"] = (max_fontsize + min_fontsize) / 2.0 # Prep data for inlining or storage if enable_search: diff --git a/datamapplot/palette_handling.py b/datamapplot/palette_handling.py index 256a45e..d544ad4 100644 --- a/datamapplot/palette_handling.py +++ b/datamapplot/palette_handling.py @@ -11,6 +11,9 @@ def palette_from_datamap( theta_range=np.pi / 16, radius_weight_power=1.0, ): + if label_locations.shape[0] == 0: + return [] + data_center = np.asarray( umap_coords.min(axis=0) + (umap_coords.max(axis=0) - umap_coords.min(axis=0)) / 2 @@ -143,6 +146,9 @@ def palette_from_cmap_and_datamap( theta_range=np.pi / 16, radius_weight_power=1.0, ): + if label_locations.shape[0] == 0: + return [cmap(0.5)] + endpoints = cmap((0.0, 1.0)) endpoint_distance = np.sum((endpoints[0] - endpoints[1]) ** 2) if endpoint_distance < 0.05: diff --git a/datamapplot/plot_rendering.py b/datamapplot/plot_rendering.py index ed1e0cd..f14b548 100644 --- a/datamapplot/plot_rendering.py +++ b/datamapplot/plot_rendering.py @@ -402,132 +402,142 @@ def render_plot( # Find initial placements for text, fix any line crossings, then optimize placements ax.autoscale_view() - label_text_locations = initial_text_location_placement( - label_locations, - base_radius=label_base_radius, - theta_stretch=label_direction_bias, - ) - fix_crossings(label_text_locations, label_locations) + if label_locations.shape[0] > 0: + label_text_locations = initial_text_location_placement( + label_locations, + base_radius=label_base_radius, + theta_stretch=label_direction_bias, + ) + fix_crossings(label_text_locations, label_locations) + + font_scale_factor = np.sqrt(figsize[0] * figsize[1]) + if label_font_size is None: + font_size = estimate_font_size( + label_text_locations, + label_text, + 0.9 * font_scale_factor, + fontfamily=fontfamily, + linespacing=label_linespacing, + ax=ax, + ) + else: + font_size = label_font_size - font_scale_factor = np.sqrt(figsize[0] * figsize[1]) - if label_font_size is None: - font_size = estimate_font_size( + # Ensure we can look up labels for highlighting + if highlight_labels is not None: + highlight = set(highlight_labels) + else: + highlight = set([]) + + label_text_locations = adjust_text_locations( label_text_locations, + label_locations, label_text, - 0.9 * font_scale_factor, fontfamily=fontfamily, + font_size=font_size, linespacing=label_linespacing, + highlight=highlight, + highlight_label_keywords=highlight_label_keywords, ax=ax, + expand=(label_margin_factor, label_margin_factor), + label_size_adjustments=label_size_adjustments, ) - else: - font_size = label_font_size - - # Ensure we can look up labels for highlighting - if highlight_labels is not None: - highlight = set(highlight_labels) - else: - highlight = set([]) - - label_text_locations = adjust_text_locations( - label_text_locations, - label_locations, - label_text, - fontfamily=fontfamily, - font_size=font_size, - linespacing=label_linespacing, - highlight=highlight, - highlight_label_keywords=highlight_label_keywords, - ax=ax, - expand=(label_margin_factor, label_margin_factor), - label_size_adjustments=label_size_adjustments, - ) - # Build highlight boxes - if ( - "bbox" in highlight_label_keywords - and highlight_label_keywords["bbox"] is not None - ): - base_bbox_keywords = highlight_label_keywords["bbox"] - else: - base_bbox_keywords = None - - # Add the annotations to the plot - texts = [] - for i in range(label_locations.shape[0]): - if base_bbox_keywords is not None: - bbox_keywords = dict(base_bbox_keywords.items()) - if "fc" not in base_bbox_keywords: - if highlight_colors is not None: - bbox_keywords["fc"] = highlight_colors[i][:7] + "33" - else: - bbox_keywords["fc"] = "#cccccc33" if darkmode else "#33333333" - if "ec" not in base_bbox_keywords: - bbox_keywords["ec"] = "none" - else: - bbox_keywords = None - - if label_text_colors: - text_color = label_text_colors[i] - elif darkmode: - text_color = "white" + # Build highlight boxes + if ( + "bbox" in highlight_label_keywords + and highlight_label_keywords["bbox"] is not None + ): + base_bbox_keywords = highlight_label_keywords["bbox"] else: - text_color = "black" - - if label_arrow_colors: - arrow_color = label_arrow_colors[i] - elif darkmode: - arrow_color = "#dddddd" - else: - arrow_color = "#333333" - - texts.append( - ax.annotate( - label_text[i], - label_locations[i], - xytext=label_text_locations[i], - ha="center", - ma="center", - va="center", - linespacing=label_linespacing, - fontfamily=fontfamily, - arrowprops={ - "arrowstyle": "-", - "linewidth": 0.5, - "color": arrow_color, - **arrowprops, - }, - fontsize=( - highlight_label_keywords.get("fontsize", font_size) + base_bbox_keywords = None + + # Add the annotations to the plot + texts = [] + for i in range(label_locations.shape[0]): + if base_bbox_keywords is not None: + bbox_keywords = dict(base_bbox_keywords.items()) + if "fc" not in base_bbox_keywords: + if highlight_colors is not None: + bbox_keywords["fc"] = highlight_colors[i][:7] + "33" + else: + bbox_keywords["fc"] = "#cccccc33" if darkmode else "#33333333" + if "ec" not in base_bbox_keywords: + bbox_keywords["ec"] = "none" + else: + bbox_keywords = None + + if label_text_colors: + text_color = label_text_colors[i] + elif darkmode: + text_color = "white" + else: + text_color = "black" + + if label_arrow_colors: + arrow_color = label_arrow_colors[i] + elif darkmode: + arrow_color = "#dddddd" + else: + arrow_color = "#333333" + + texts.append( + ax.annotate( + label_text[i], + label_locations[i], + xytext=label_text_locations[i], + ha="center", + ma="center", + va="center", + linespacing=label_linespacing, + fontfamily=fontfamily, + arrowprops={ + "arrowstyle": "-", + "linewidth": 0.5, + "color": arrow_color, + **arrowprops, + }, + fontsize=( + highlight_label_keywords.get("fontsize", font_size) + if label_text[i] in highlight + else font_size + ) + + ( + label_size_adjustments[i] + if label_size_adjustments is not None + else 0.0 + ), + bbox=bbox_keywords if label_text[i] in highlight else None, + color=text_color, + fontweight=highlight_label_keywords.get("fontweight", "normal") if label_text[i] in highlight - else font_size + else "normal", ) - + ( - label_size_adjustments[i] - if label_size_adjustments is not None - else 0.0 - ), - bbox=bbox_keywords if label_text[i] in highlight else None, - color=text_color, - fontweight=highlight_label_keywords.get("fontweight", "normal") - if label_text[i] in highlight - else "normal", ) - ) - # Ensure we have plot bounds that meet the newly place annotations - coords = get_2d_coordinates(texts) - x_min, y_min = ax.transData.inverted().transform( - (coords[:, [0, 2]].copy().min(axis=0)) - ) - x_max, y_max = ax.transData.inverted().transform( - (coords[:, [1, 3]].copy().max(axis=0)) - ) - width = x_max - x_min - height = y_max - y_min - x_min -= 0.05 * width - x_max += 0.05 * width - y_min -= 0.05 * height - y_max += 0.05 * height + # Ensure we have plot bounds that meet the newly place annotations + coords = get_2d_coordinates(texts) + x_min, y_min = ax.transData.inverted().transform( + (coords[:, [0, 2]].copy().min(axis=0)) + ) + x_max, y_max = ax.transData.inverted().transform( + (coords[:, [1, 3]].copy().max(axis=0)) + ) + width = x_max - x_min + height = y_max - y_min + x_min -= 0.05 * width + x_max += 0.05 * width + y_min -= 0.05 * height + y_max += 0.05 * height + else: + x_min, y_min = data_map_coords.min(axis=0) + x_max, y_max = data_map_coords.max(axis=0) + width = x_max - x_min + height = y_max - y_min + x_min -= 0.05 * width + x_max += 0.05 * width + y_min -= 0.05 * height + y_max += 0.05 * height # decorate the plot ax.set(xticks=[], yticks=[])