-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #473 from aristoteleo/external_to_preprocess
External to preprocess
- Loading branch information
Showing
10 changed files
with
597 additions
and
221 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from typing import Optional | ||
|
||
from anndata import AnnData | ||
from matplotlib.axes import Axes | ||
from matplotlib.figure import Figure | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
|
||
def sctransform_plot_fit(
    adata: AnnData,
    xaxis: str = "gmean",
    fig: Optional[Figure] = None,
) -> Figure:
    """Plot the fitting of model parameters in sctransform.

    One subplot is drawn per model parameter (``log_umi``, ``Intercept`` and ``theta``). Each subplot overlays the
    per-gene step-1 estimates (read from the ``*_step1*`` columns of ``adata.var``) and the regularized estimates
    (read from the ``*_sct`` columns) against ``adata.var["log10_gmean_sct"]``. ``theta`` is shown on a log10 scale.

    Args:
        adata: annotated data matrix after sctransform.
        xaxis: the gene expression metric is plotted on the x-axis. Only used to build the x-axis label.
        fig: Matplotlib figure object to use for the plot. If not provided, a new figure is created.

    Returns:
        The matplotlib figure object containing the plot.
    """
    if fig is None:
        fig = plt.figure(figsize=(12, 3))

    # Genes used in step 1 are the ones with a non-missing `genes_step1_sct` entry.
    step1_flags = adata.var["genes_step1_sct"]
    gene_names = step1_flags[~step1_flags.isna()].index

    genes_log10_mean = adata.var["log10_gmean_sct"]
    # Select the x-values by gene name so they are guaranteed to line up with the step-1 parameter rows below,
    # rather than relying on two independent NaN masks producing the same genes in the same order.
    step1_log10_mean = genes_log10_mean.loc[gene_names]

    model_params_fit = adata.var[["log_umi_sct", "Intercept_sct", "theta_sct"]].rename(
        columns={"log_umi_sct": "log_umi", "Intercept_sct": "Intercept", "theta_sct": "theta"}
    )
    model_params = adata.var.loc[
        gene_names, ["log_umi_step1_sct", "Intercept_step1_sct", "model_pars_theta_step1"]
    ].rename(
        columns={
            "log_umi_step1_sct": "log_umi",
            "Intercept_step1_sct": "Intercept",
            "model_pars_theta_step1": "theta",
        }
    )

    total_params = model_params_fit.shape[1]

    for index, column in enumerate(model_params_fit.columns):
        ax = fig.add_subplot(1, total_params, index + 1)
        step1_values = model_params[column]
        fit_values = model_params_fit[column]

        if column == "theta":
            # theta is displayed on a log10 scale; the other parameters are plotted as-is.
            step1_values = np.log10(step1_values)
            fit_values = np.log10(fit_values)
            ylabel = "log10(" + column + ")"
        else:
            ylabel = column

        ax.scatter(
            step1_log10_mean,
            step1_values,
            s=1,
            label="single gene estimate",
            color="#2b8cbe",
        )
        ax.scatter(
            genes_log10_mean,
            fit_values,
            s=2,
            label="regularized",
            color="#de2d26",
        )
        ax.set_ylabel(ylabel)
        ax.set_xlabel(f"log10(gene_{xaxis})")
        ax.set_title(column)
        ax.legend(frameon=False)

    fig.tight_layout()
    return fig
|
||
def plot_residual_var(
    adata: AnnData,
    topngenes: int = 10,
    label_genes: bool = True,
    ax: Optional[Axes] = None,
) -> Optional[Figure]:
    """Plot the relationship between the mean and variance of gene expression across cells, highlighting the genes with
    the highest residual variance.

    The per-gene mean and variance are computed over the rows of ``adata.X``. The ``topngenes`` genes with the largest
    variance are drawn in a different color and, optionally, labeled with their name from ``adata.var``'s index.

    Args:
        adata: annotated data matrix after sctransform.
        topngenes: the number of genes with the highest residual variance to highlight in the plot.
        label_genes: whether to label the highlighted genes in the plot. If `topngenes` is large, labeling genes may
            lead to plotting error because of the space limitation.
        ax: the axes on which to draw the plot. If None, a new figure and axes are created.

    Returns:
        The Figure object if `ax` is not given, else None.
    """
    # Imported locally because this module does not otherwise depend on scipy at import time.
    from scipy.sparse import issparse

    def _column_mean_and_var(matrix):
        """Return per-column (mean, variance) as 1-D arrays using var = mean(a**2) - mean(a)**2."""
        if issparse(matrix):
            # Element-wise square that keeps the matrix sparse.
            squared = matrix.multiply(matrix)
        else:
            matrix = np.asarray(matrix)
            squared = np.square(matrix)
        # Sparse matrices return a (1, n) np.matrix from `.mean`; flatten every case to 1-D.
        mean = np.asarray(matrix.mean(axis=0)).ravel()
        mean_of_squares = np.asarray(squared.mean(axis=0)).ravel()
        return mean, mean_of_squares - np.square(mean)

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 5))
    else:
        fig = None

    gene_mean, gene_var = _column_mean_and_var(adata.X)
    gene_attr = pd.DataFrame(adata.var["log10_gmean_sct"])
    gene_attr["var"] = gene_var
    gene_attr["mean"] = gene_mean

    # Keep the gene names in the index so labeling does not depend on what `reset_index` would call the column.
    gene_attr_sorted = gene_attr.sort_values("var", ascending=False)
    topn = gene_attr_sorted.iloc[:topngenes]
    remaining = gene_attr_sorted.iloc[topngenes:]

    ax.set_xscale("log")
    ax.scatter(remaining["mean"], remaining["var"], s=1.5, color="black")
    ax.scatter(topn["mean"], topn["var"], s=1.5, color="deeppink")
    ax.axhline(1, linestyle="dashed", color="red")
    ax.set_xlabel("mean")
    ax.set_ylabel("var")

    if label_genes:
        # Draw on the target axes explicitly; `plt.text` would write to whichever axes is current.
        for gene, row in topn.iterrows():
            ax.text(row["mean"], row["var"], gene)

    # `fig` only exists when this function created it; a caller-supplied axes keeps its own layout.
    if fig is not None:
        fig.tight_layout()
    return fig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .integration import harmony_debatch, integrate | ||
from .pearson_residual_recipe import ( | ||
normalize_layers_pearson_residuals, | ||
select_genes_by_pearson_residuals, | ||
) | ||
from .sctransform import sctransform | ||
|
||
# Names exported by a star-import of this package; each one is imported from a submodule above.
__all__ = [
    "normalize_layers_pearson_residuals",
    "sctransform",
    "select_genes_by_pearson_residuals",
    "harmony_debatch",
    "integrate",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import List, Optional, Union | ||
|
||
import numpy as np | ||
from anndata import AnnData | ||
from scipy.sparse import csr_matrix, isspmatrix | ||
|
||
def to_dense_matrix(X):
    """Convert a sparse matrix to a dense ``numpy.ndarray``.

    Args:
        X: a scipy sparse matrix or any array-like object.

    Returns:
        A dense ``numpy.ndarray``. Sparse input is densified with ``toarray()``; any other input is passed through
        ``numpy.asarray`` (so an existing ndarray is returned without copying).
    """
    if isspmatrix(X):
        return X.toarray()
    return np.asarray(X)
|
||
def integrate(
    adatas: List[AnnData],
    batch_key: str = "slices",
    fill_value: Union[int, float] = 0,
) -> AnnData:
    """Concatenate all AnnData objects into one.

    The batch label of each object is taken from the first entry of ``adata.obs[batch_key]``. ``obsm`` and ``varm``
    entries that exist in every input are densified and stacked along axis 0; ``uns`` entries are collected into a
    dict keyed by batch label (``None`` where an object lacks the key), except ``"__type"``, which is copied from the
    first object that defines it.

    Note:
        The ``obsm``, ``varm`` and ``uns`` attributes of every object in ``adatas`` are deleted in place before
        concatenation, so the inputs are modified by this call.

    Args:
        adatas: AnnData matrices to concatenate with.
        batch_key: the key to add the batch annotation to :attr:`obs`.
        fill_value: Scalar value to fill newly missing values in arrays with.

    Returns:
        The concatenated AnnData, where adata.obs[batch_key] stores a categorical variable labeling the batch.
    """

    def _keys_in_all(mappings):
        """Return the keys present in every mapping, in first-seen order."""
        seen = dict.fromkeys(key for mapping in mappings for key in mapping.keys())
        return [key for key in seen if all(key in mapping.keys() for mapping in mappings)]

    # `.iloc[0]` is an explicit positional lookup; `obs[batch_key][0]` would be treated as a label on the obs index.
    batch_ca = [adata.obs[batch_key].iloc[0] for adata in adatas]

    # Only keys shared by every object can be stacked; a key missing from any object would raise a KeyError.
    obsm_keys = _keys_in_all([adata.obsm for adata in adatas])
    varm_keys = _keys_in_all([adata.varm for adata in adatas])
    # `uns` tolerates missing keys (filled with None below), so the union is used, in deterministic order.
    uns_keys = list(dict.fromkeys(key for adata in adatas for key in adata.uns_keys()))

    # Merge the obsm, varm and uns data of all anndata objects separately.
    obsm_dict = {
        key: np.concatenate([to_dense_matrix(adata.obsm[key]) for adata in adatas], axis=0) for key in obsm_keys
    }
    # NOTE(review): varm is stacked along axis 0 like obsm; this presumably only matches the concatenated object's
    # variable count in specific cases — confirm against callers.
    varm_dict = {
        key: np.concatenate([to_dense_matrix(adata.varm[key]) for adata in adatas], axis=0) for key in varm_keys
    }
    uns_dict = {}
    for key in uns_keys:
        if key == "__type":
            # Take the value from the first object that actually carries it.
            uns_dict[key] = next(adata.uns[key] for adata in adatas if key in adata.uns_keys())
        else:
            uns_dict[key] = {
                ca: adata.uns[key] if key in adata.uns_keys() else None for ca, adata in zip(batch_ca, adatas)
            }

    # Delete obsm, varm and uns data.
    for adata in adatas:
        del adata.obsm, adata.varm, adata.uns

    # Concatenating obs and var data which will ignore the uns, obsm, varm attributes.
    integrated_adata = adatas[0].concatenate(
        *adatas[1:],
        batch_key=batch_key,
        batch_categories=batch_ca,
        join="outer",
        fill_value=fill_value,
        uns_merge=None,
    )

    # Add concatenated obsm, varm and uns data to the integrated anndata object.
    for key, value in obsm_dict.items():
        integrated_adata.obsm[key] = value
    for key, value in varm_dict.items():
        integrated_adata.varm[key] = value
    for key, value in uns_dict.items():
        integrated_adata.uns[key] = value

    return integrated_adata
|
||
def harmony_debatch(
    adata: AnnData,
    key: str,
    basis: str = "X_pca",
    adjusted_basis: str = "X_pca_harmony",
    max_iter_harmony: int = 10,
    copy: bool = False,
) -> Optional[AnnData]:
    """Use harmonypy [Korunsky19]_ to remove batch effects.

    This function should be run after performing PCA but before computing the neighbor graph. Original Code Repository
    is https://github.com/slowkow/harmonypy. Interesting example: https://slowkow.com/notes/harmony-animation/

    Args:
        adata: An Anndata object.
        key: The name of the column in ``adata.obs`` that differentiates among experiments/batches.
        basis: The name of the field in ``adata.obsm`` where the PCA table is stored.
        adjusted_basis: The name of the field in ``adata.obsm`` where the adjusted PCA table will be stored after
            running this function.
        max_iter_harmony: Maximum number of rounds to run Harmony. One round of Harmony involves one clustering and one
            correction step.
        copy: Whether to copy `adata` or modify it inplace.

    Returns:
        Updates adata with the field ``adata.obsm[adjusted_basis]``, containing principal components adjusted by
        Harmony. The updated copy is returned if `copy` is True, otherwise None.

    Raises:
        ImportError: if `harmonypy` is not installed.
    """
    try:
        import harmonypy
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible in the traceback.
        raise ImportError("\nplease install harmonypy:\n\n\tpip install harmonypy") from err

    adata = adata.copy() if copy else adata

    # Convert sparse matrix to dense matrix.
    matrix = to_dense_matrix(adata.obsm[basis])

    # Use Harmony to adjust the PCs.
    harmony_out = harmonypy.run_harmony(matrix, adata.obs, key, max_iter_harmony=max_iter_harmony)
    adjusted_matrix = harmony_out.Z_corr.T

    # Store the result in the same (sparse or dense) format as the input basis.
    if isspmatrix(adata.obsm[basis]):
        adjusted_matrix = csr_matrix(adjusted_matrix)

    adata.obsm[adjusted_basis] = adjusted_matrix

    return adata if copy else None
Oops, something went wrong.