diff --git a/trviz/main.py b/trviz/main.py index 464c9cf..d559ad4 100644 --- a/trviz/main.py +++ b/trviz/main.py @@ -37,6 +37,7 @@ def generate_trplot(self, # Figure parameters figure_size: Tuple[int, int] = None, + motif_map_size: Tuple[int, int] = None, output_name: str = None, dpi: int = 300, hide_xticks: bool = False, @@ -56,6 +57,13 @@ def generate_trplot(self, color_palette: str = None, colormap: ListedColormap = None, motif_style: str = 'box', + xtick_step: int = 1, + ytick_step: int = 1, + xtick_offset: int = 0, + ytick_offset: int = 0, + xlabel: str = None, + ylabel: str = None, + colored_motifs: List[str] = None, verbose: bool = True, ): """ @@ -104,6 +112,14 @@ def generate_trplot(self, :param color_palette: str = None, :param colormap: ListedColormap = None, :param motif_style: 'box' (default), 'arrow', or 'triangle'. + :param xtick_step: x tick step (default is 1) + :param ytick_step: y tick step (default is 1) + :param xtick_offset: x tick offset (default is 0) + :param ytick_offset: y tick offset (default is 0) + :param motif_map_size: motif map size + :param colored_motifs: only these motifs will be colored in the plot. Others will have the same color. + :param ylabel: y label in the plot + :param xlabel: x label in the plot :param verbose: if true, output detailed information """ @@ -202,6 +218,13 @@ def generate_trplot(self, color_palette=color_palette, colormap=colormap, motif_style=motif_style, + xtick_step=xtick_step, + ytick_step=ytick_step, + xtick_offset=xtick_offset, + ytick_offset=ytick_offset, + xlabel=xlabel, + ylabel=ylabel, + colored_motifs=colored_motifs, ) # 6. 
Motif map @@ -210,4 +233,5 @@ def generate_trplot(self, self.visualizer.symbol_to_color, show_figure=show_figure, file_name=f"{output_dir}/{str(tr_id)}_motif_map.png", + figure_size=motif_map_size, ) diff --git a/trviz/visualizer.py b/trviz/visualizer.py index 5a5dfd2..f7aac02 100644 --- a/trviz/visualizer.py +++ b/trviz/visualizer.py @@ -1,7 +1,10 @@ from typing import List, Tuple, Dict +from collections import Counter import matplotlib.pyplot as plt -import numpy as np from matplotlib.colors import ListedColormap +from matplotlib.ticker import IndexLocator +import numpy as np +import distinctipy from trviz.utils import PRIVATE_MOTIF_LABEL @@ -27,13 +30,14 @@ def encode_tr_sequence(labeled_motifs): @staticmethod def _get_unique_labels(aligned_repeats): - unique_repeats = set() + unique_repeats = Counter() for rs in aligned_repeats: for r in rs: if r != '-': - unique_repeats.add(r) + unique_repeats[r] += 1 - return sorted(unique_repeats) + # sort by the frequency + return list(sorted(unique_repeats, key=unique_repeats.get, reverse=True)) @staticmethod def plot_motif_color_map(symbol_to_motif, motif_counter, symbol_to_color, file_name, @@ -63,10 +67,9 @@ def plot_motif_color_map(symbol_to_motif, motif_counter, symbol_to_color, file_n ax.set_aspect('equal') max_motif_length = len(max(symbol_to_motif.values(), key=len)) - # TODO: sort by the frequency # TODO: figure and font size yticklabels = [] - y_base_position = 0 + y_base_position = len(symbol_to_color) - 1 has_gap_in_symbol_to_color = False for (symbol, color) in symbol_to_color.items(): if symbol == '-': @@ -79,12 +82,9 @@ def plot_motif_color_map(symbol_to_motif, motif_counter, symbol_to_color, file_n linewidth=0, facecolor=color, edgecolor="white")) - y_base_position += 1 - - # text_position = [box_position[0] + 2.5 * box_width / 2, box_position[1] + box_height / 2 - 0.1] + y_base_position -= 1 motif = symbol_to_motif[symbol] text = f"{motif:<{max_motif_length + 2}}" + f"{motif_counter[motif]:>5}" - # 
ax.text(x=text_position[0], y=text_position[1], s=text, fontname="monospace") yticklabels.append(text) ax.yaxis.tick_right() @@ -93,7 +93,7 @@ def plot_motif_color_map(symbol_to_motif, motif_counter, symbol_to_color, file_n if has_gap_in_symbol_to_color: yticks_count = len(symbol_to_color) - 1 ax.set_yticks([y * box_size + box_size*0.5 for y in range(yticks_count)]) # minus 1 for gap, '-' - ax.set_yticklabels(yticklabels) + ax.set_yticklabels(yticklabels[::-1]) # font size if label_size is not None: @@ -146,6 +146,13 @@ def trplot(self, color_palette: str = None, colormap: ListedColormap = None, motif_style: str = "box", + xtick_step: int = 1, + ytick_step: int = 1, + xtick_offset: int = 0, + ytick_offset: int = 0, + xlabel: str = None, + ylabel: str = None, + colored_motifs: List[str] = None, debug: bool = False ): """ @@ -181,6 +188,10 @@ def trplot(self, :param color_palette: color palette name (see https://matplotlib.org/stable/users/explain/colors/colormaps.html) :param colormap: Matplotlib ListedColormap object :param motif_style: motif style. One of "box", "arrow", "triangle" + :param xtick_step: x tick step (default is 1) + :param ytick_step: y tick step (default is 1) + :param xtick_offset: x tick offset (default is 0) + :param ytick_offset: y tick offset (default is 0) :param debug: if true, print verbose information. 
""" @@ -196,7 +207,7 @@ def trplot(self, else: fig, ax_main = plt.subplots(figsize=figure_size) # width and height in inch - # Add clustering + # Sort by clustering and add dendrogram if needed if sort_by_clustering: if symbol_to_motif is None: raise ValueError("symbol_to_motif must be provided when sort_by_clustering is True") @@ -216,13 +227,124 @@ def trplot(self, print("No clustering") # Set symbol to color map + self.set_symbol_to_motif_map(aligned_labeled_repeats, alpha, color_palette, colored_motifs, colormap, + symbol_to_motif) + + self.draw_motifs(allele_as_row, ax_main, box_line_width, motif_marks, motif_style, no_edge, private_motif_color, + sorted_aligned_labeled_repeats, sorted_sample_ids) + + # Add another axis for sample labels + self.add_label_color_axis(aligned_labeled_repeats, allele_as_row, ax_main, box_line_width, sample_to_label, + sorted_aligned_labeled_repeats, sorted_sample_ids, xlabel_rotation, xlabel_size, + ylabel_rotation, ylabel_size) + + self.set_ticks_and_labels(aligned_labeled_repeats, allele_as_row, ax_main, hide_xticks, hide_yticks, + max_repeat_count, sample_to_label, sorted_aligned_labeled_repeats, sorted_sample_ids, + xlabel, xlabel_rotation, xlabel_size, xtick_offset, xtick_step, ylabel, + ylabel_rotation, ylabel_size, ytick_offset, ytick_step) + + if frame_on is None: # Set default + frame_on = {'top': False, 'bottom': True, 'right': False, 'left': True} + + ax_main.spines['top'].set_visible(frame_on['top']) + ax_main.spines['right'].set_visible(frame_on['right']) + ax_main.spines['bottom'].set_visible(frame_on['bottom']) + ax_main.spines['left'].set_visible(frame_on['left']) + + if output_name is not None: + if '.' 
not in output_name: + fig.savefig(f"{output_name}.pdf", dpi=dpi, bbox_inches='tight') + else: + fig.savefig(f"{output_name}", dpi=dpi, bbox_inches='tight') + else: + fig.savefig("test_trplot.png", dpi=dpi, bbox_inches='tight') + + if show_figure: + plt.show() + plt.close(fig) + + def set_symbol_to_motif_map(self, aligned_labeled_repeats, alpha, color_palette, colored_motifs, colormap, + symbol_to_motif): unique_labels = self._get_unique_labels(aligned_labeled_repeats) unique_label_count = len(unique_labels) - symbol_to_color = self.get_symbol_to_color_map(alpha, unique_label_count, unique_labels, - color_palette=color_palette, - colormap=colormap) - self.set_symbol_to_color_map(symbol_to_color) + if colored_motifs is None: + self.symbol_to_color = self.get_symbol_to_color_map(alpha, unique_label_count, unique_labels, + color_palette=color_palette, + colormap=colormap) + else: + # Only assign colors to unique motifs in the colored motifs + if colored_motifs is not None: + distinct_colors = distinctipy.get_colors(len(colored_motifs), pastel_factor=0.9, rng=777) + cmap = ListedColormap(distinct_colors) + if color_palette is not None: + cmap = plt.get_cmap(color_palette) + + for i, unique_motif in enumerate(colored_motifs): + for symbol, motif in symbol_to_motif.items(): + if unique_motif == motif: + self.symbol_to_color[symbol] = cmap(i) + + other_motifs = set(unique_labels) - set(self.symbol_to_color.keys()) + for motif in other_motifs: + self.symbol_to_color[motif] = 'grey' + + def set_ticks_and_labels(self, aligned_labeled_repeats, allele_as_row, ax_main, hide_xticks, hide_yticks, + max_repeat_count, sample_to_label, sorted_aligned_labeled_repeats, sorted_sample_ids, + xlabel, xlabel_rotation, xlabel_size, xtick_offset, xtick_step, ylabel, ylabel_rotation, + ylabel_size, ytick_offset, ytick_step): + if allele_as_row: + ax_main.set_ylim(top=len(sorted_aligned_labeled_repeats)) + if sample_to_label is None: + if ytick_step == 1: + ax_main.set_yticks([x + 0.5 for x in 
range(len(aligned_labeled_repeats))]) + ax_main.set_yticklabels(sorted_sample_ids, ha='right', rotation=ylabel_rotation) + else: + ax_main.yaxis.set_major_locator(IndexLocator(base=ytick_step, offset=ytick_offset)) + else: + ax_main.set_yticks([]) + if xtick_step == 1: + label_positions = [x + 0.5 for x in range(max_repeat_count)] + labels = [x for x in range(1, max_repeat_count + 1)] + ax_main.set_xticks(label_positions, labels=labels, rotation=xlabel_rotation) + ax_main.set_xticklabels(labels, rotation=xlabel_rotation) + else: + ax_main.xaxis.set_major_locator(IndexLocator(base=xtick_step, offset=xtick_offset)) + ax_main.set_xlim(right=max_repeat_count) + else: + ax_main.set_xlim(right=len(sorted_aligned_labeled_repeats)) + if sample_to_label is None: + if xtick_step == 1: + ax_main.set_xticks([x + 0.5 for x in range(len(aligned_labeled_repeats))]) + ax_main.set_xticklabels(sorted_sample_ids, rotation=xlabel_rotation) + else: + ax_main.xaxis.set_major_locator(IndexLocator(base=xtick_step, offset=xtick_offset)) + else: + ax_main.set_xticks([]) + + if ytick_step == 1: + label_positions = [y + 0.5 for y in range(max_repeat_count)] + labels = [y for y in range(1, max_repeat_count + 1)] + ax_main.set_yticks(label_positions, labels=labels, rotation=ylabel_rotation) + else: + ax_main.yaxis.set_major_locator(IndexLocator(base=ytick_step, offset=ytick_offset)) + ax_main.set_ylim(top=max_repeat_count) + ax_main.tick_params(axis='y', which='major', labelsize=ylabel_size) + ax_main.tick_params(axis='y', which='minor', labelsize=ylabel_size) + ax_main.tick_params(axis='x', which='major', labelsize=xlabel_size) + ax_main.tick_params(axis='x', which='minor', labelsize=xlabel_size) + if hide_xticks: + ax_main.set_xticks([]) + if hide_yticks: + ax_main.set_yticks([]) + # Label + if xlabel is not None: + ax_main.set_xlabel(xlabel, fontsize=xlabel_size) + if ylabel is not None: + ax_main.set_ylabel(ylabel, fontsize=ylabel_size) + + def draw_motifs(self, allele_as_row, ax_main, 
box_line_width, motif_marks, motif_style, no_edge, + private_motif_color, sorted_aligned_labeled_repeats, sorted_sample_ids): box_height = 1.0 box_width = 1.0 for allele_index, allele in enumerate(sorted_aligned_labeled_repeats): @@ -245,7 +367,7 @@ def trplot(self, motif_mark = motif_marks[sorted_sample_ids[allele_index]][motif_index] if motif_mark == 'I': # introns hatch_pattern = 'xxx' - fcolor = symbol_to_color[symbol] + fcolor = self.symbol_to_color[symbol] motif_index += 1 if symbol == PRIVATE_MOTIF_LABEL: @@ -258,85 +380,23 @@ def trplot(self, edgecolor=fcolor if no_edge else "white", hatch=hatch_pattern, )) elif motif_style == "arrow": - ax_main.add_patch(plt.Arrow(box_position[0], box_position[1]+box_height/2, box_width, 0, - width=1, - linewidth=box_line_width + 0.1, - facecolor=fcolor, - edgecolor=fcolor if no_edge else "white", - hatch=hatch_pattern, )) + ax_main.add_patch(plt.Arrow(box_position[0], box_position[1] + box_height / 2, box_width, 0, + width=1, + linewidth=box_line_width + 0.1, + facecolor=fcolor, + edgecolor=fcolor if no_edge else "white", + hatch=hatch_pattern, )) elif motif_style == "triangle": ax_main.add_patch(plt.Polygon([(box_position[0], box_position[1]), - (box_position[0], box_position[1] + box_height), - (box_position[0] + box_width, box_position[1] + box_height/2)], + (box_position[0], box_position[1] + box_height), + (box_position[0] + box_width, box_position[1] + box_height / 2)], color=fcolor, linewidth=box_line_width + 0.1, )) else: raise ValueError(f"Unknown motif style: {motif_style}") - - # Add colors based on sample labels - self.add_label_color_axis(aligned_labeled_repeats, allele_as_row, ax_main, box_line_width, sample_to_label, - sorted_aligned_labeled_repeats, sorted_sample_ids, xlabel_rotation, xlabel_size, - ylabel_rotation, ylabel_size) - - if allele_as_row: - ax_main.set_ylim(top=len(sorted_aligned_labeled_repeats)) - if sample_to_label is None: - ax_main.set_yticks([x + 0.5 for x in 
range(len(aligned_labeled_repeats))]) - ax_main.set_yticklabels(sorted_sample_ids, ha='right', rotation=ylabel_rotation) - else: - ax_main.set_yticks([]) - - label_positions = [x + 0.5 for x in range(max_repeat_count)] - labels = [x for x in range(1, max_repeat_count + 1)] - ax_main.set_xticks(label_positions, labels=labels, rotation=xlabel_rotation) - ax_main.set_xlim(right=max_repeat_count) - else: - ax_main.set_xlim(right=len(sorted_aligned_labeled_repeats)) - if sample_to_label is None: - ax_main.set_xticks([x + 0.5 for x in range(len(aligned_labeled_repeats))]) - ax_main.set_xticklabels(sorted_sample_ids, rotation=xlabel_rotation) - else: - ax_main.set_xticks([]) - - label_positions = [y + 0.5 for y in range(max_repeat_count)] - labels = [y for y in range(1, max_repeat_count + 1)] - ax_main.set_yticks(label_positions, labels=labels, rotation=ylabel_rotation) - ax_main.set_ylim(top=max_repeat_count) - - ax_main.tick_params(axis='y', which='major', labelsize=ylabel_size) - ax_main.tick_params(axis='y', which='minor', labelsize=ylabel_size) - ax_main.tick_params(axis='x', which='major', labelsize=xlabel_size) - ax_main.tick_params(axis='x', which='minor', labelsize=xlabel_size) - - if hide_xticks: - ax_main.set_xticks([]) - if hide_yticks: - ax_main.set_yticks([]) - - # Frame - if frame_on is None: # Set default - frame_on = {'top': False, 'bottom': True, 'right': False, 'left': True} - - ax_main.spines['top'].set_visible(frame_on['top']) - ax_main.spines['right'].set_visible(frame_on['right']) - ax_main.spines['bottom'].set_visible(frame_on['bottom']) - ax_main.spines['left'].set_visible(frame_on['left']) - - if output_name is not None: - if '.' 
not in output_name: - fig.savefig(f"{output_name}.pdf", dpi=dpi, bbox_inches='tight') - else: - fig.savefig(f"{output_name}", dpi=dpi, bbox_inches='tight') - else: - fig.savefig("test_trplot.png", dpi=dpi, bbox_inches='tight') - - if show_figure: - plt.show() - plt.close(fig) - def add_label_color_axis(self, aligned_labeled_repeats, allele_as_row, ax_main, box_line_width, sample_to_label, sorted_aligned_labeled_repeats, sorted_sample_ids, xlabel_rotation, xlabel_size, ylabel_rotation, ylabel_size): @@ -449,6 +509,7 @@ def sort_by_clustering(self, ax_main, aligned_labeled_repeats, sample_ids, symbo @staticmethod def get_symbol_to_color_map(alpha, unique_symbol_count, unique_symbols, color_palette=None, colormap=None): + """ Get a dictionary mapping symbols to colors """ if colormap is not None: cmap = colormap @@ -463,14 +524,13 @@ def get_symbol_to_color_map(alpha, unique_symbol_count, unique_symbols, color_pa (0.835, 0.369, 0), (0.8, 0.475, 0.655)]) else: - import distinctipy cmap = distinctipy.get_colors(unique_symbol_count, pastel_factor=0.9, rng=777) cmap = ListedColormap(cmap) if color_palette is not None: cmap = plt.get_cmap(color_palette) - symbol_to_color = {r: cmap(i) for i, r in enumerate(list(unique_symbols))} + symbol_to_color = {r: cmap(i) for i, r in enumerate(unique_symbols)} return symbol_to_color def set_symbol_to_color_map(self, symbol_to_color):