Skip to content

Commit

Permalink
Merge pull request #485 from aristoteleo/update_pp_tests
Browse files Browse the repository at this point in the history
Update pp tests
  • Loading branch information
Xiaojieqiu authored May 4, 2023
2 parents b8aecad + a0766d4 commit 7398a6a
Showing 1 changed file with 122 additions and 1 deletion.
123 changes: 122 additions & 1 deletion tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# from utils import *
import dynamo as dyn
from dynamo.preprocessing import Preprocessor
from dynamo.preprocessing.cell_cycle import get_cell_phase
from dynamo.preprocessing.preprocessor_utils import (
calc_mean_var_dispersion_sparse,
is_float_integer_arr,
Expand Down Expand Up @@ -184,6 +185,7 @@ def test_pca():
assert np.linalg.norm(pca.explained_variance_ratio_[:10] - adata.uns["explained_variance_ratio_"][:10]) < 1e-1



def test_preprocessor_seurat(adata):
adata = dyn.sample_data.zebrafish()
preprocessor = dyn.pp.Preprocessor()
Expand Down Expand Up @@ -228,6 +230,121 @@ def test_is_nonnegative():
assert not is_nonnegative_integer_arr(test_mat)


def test_filter_genes_by_clusters_():
    """Check the per-gene mask returned by ``dyn.pp.filter_genes_by_clusters_``.

    Builds a random dense expression matrix, stores it as both the ``spliced``
    and ``unspliced`` sparse layers, assigns every cell to one of a few random
    clusters, and verifies that the returned mask is a NumPy array with one
    entry per gene that matches a reference computed directly from the
    per-cluster layer averages.
    """
    # A locally seeded generator keeps the test reproducible without
    # touching NumPy's global random state.
    rng = np.random.default_rng(0)

    # Create test data
    n_cells = 1000
    n_genes = 500
    n_clusters = 3
    data = rng.random((n_cells, n_genes))
    layers = {
        "spliced": csr_matrix(data),
        "unspliced": csr_matrix(data),
    }
    adata = anndata.AnnData(X=data, layers=layers)

    # Add cluster information
    clusters = rng.integers(low=0, high=n_clusters, size=n_cells)
    adata.obs["clusters"] = clusters

    # Filter genes by cluster
    clu_avg_selected = dyn.pp.filter_genes_by_clusters_(adata, "clusters")

    # Check that the output is a numpy array with one entry per gene
    assert isinstance(clu_avg_selected, np.ndarray)
    assert clu_avg_selected.shape == (n_genes,)

    # Reference mask: a gene is kept when, in at least one cluster, both its
    # unspliced and spliced averages exceed the thresholds.
    # NOTE(review): 0.02 / 0.08 are presumably the defaults of min_avg_U and
    # min_avg_S in filter_genes_by_clusters_ -- confirm against the function.
    U, S = adata.layers["unspliced"], adata.layers["spliced"]
    U_avgs = np.array([np.mean(U[clusters == i], axis=0) for i in range(n_clusters)])
    S_avgs = np.array([np.mean(S[clusters == i], axis=0) for i in range(n_clusters)])
    # NOTE(review): each sparse mean appears to come back as a 1 x n_genes
    # row, so the stacked arrays carry a singleton middle axis that .max(1)
    # collapses -- confirm with the scipy version in use.
    expected_clu_avg_selected = np.any((U_avgs.max(1) > 0.02) & (S_avgs.max(1) > 0.08), axis=0)
    assert np.array_equal(clu_avg_selected, expected_clu_avg_selected)


def test_filter_genes_by_outliers():
    """Exercise ``dyn.pp.filter_genes_by_outliers`` in both ``inplace`` modes.

    With ``inplace=False`` the call is expected to return an object whose
    ``.values`` is a per-gene boolean mask, leaving ``adata`` untouched; with
    ``inplace=True`` the genes failing the filter are expected to be dropped
    from ``adata`` itself.
    """
    # Small 6-cell x 4-gene count matrix.
    counts = np.array(
        [
            [1, 0, 3, 0],
            [0, 1, 1, 0],
            [0, 1, 2, 1],
            [0, 0, 0, 0],
            [0, 1, 1, 1],
            [1, 0, 0, 1],
        ]
    )
    adata = anndata.AnnData(counts)
    adata.obs_names = ["cell1", "cell2", "cell3", "cell4", "cell5", "cell6"]
    adata.var_names = ["gene1", "gene2", "gene3", "gene4"]

    # The same thresholds are used for both calls below.
    thresholds = {
        "min_avg_exp_s": 0.5,
        "min_cell_s": 2,
        "max_avg_exp": 2.5,
        "min_count_s": 2,
    }

    # Non-destructive call: only a mask comes back.
    gene_mask = dyn.pp.filter_genes_by_outliers(adata, inplace=False, **thresholds)
    assert np.all(gene_mask.values == [False, True, True, True])

    # The original object must still hold every gene.
    assert np.all(adata.var_names.values == ["gene1", "gene2", "gene3", "gene4"])

    # Destructive call: adata itself is subset to the passing genes.
    dyn.pp.filter_genes_by_outliers(adata, inplace=True, **thresholds)
    assert adata.shape == (6, 3)
    assert np.all(adata.var_names.values == ["gene2", "gene3", "gene4"])


def test_filter_cells_by_outliers():
    """Exercise ``dyn.pp.filter_cells_by_outliers`` on a tiny 4 x 3 matrix.

    First checks which cells remain after filtering with a custom
    ``min_expr_genes_s`` / ``max_expr_genes_s`` range, then checks that an
    unknown ``layer`` name raises ``ValueError``.
    """
    # Build the example AnnData object: 4 cells x 3 genes.
    counts = np.array([[1, 0, 3], [4, 0, 0], [7, 8, 9], [10, 11, 12]])
    adata = anndata.AnnData(X=counts)
    adata.var_names = ["gene1", "gene2", "gene3"]
    adata.obs_names = ["cell1", "cell2", "cell3", "cell4"]

    # Custom range values; "cell2" has a single non-zero gene and is expected
    # to be removed.
    dyn.pp.filter_cells_by_outliers(adata, min_expr_genes_s=2, max_expr_genes_s=6)

    remaining_cells = adata.obs_names.values
    assert np.array_equal(remaining_cells, ["cell1", "cell3", "cell4"])

    # An invalid layer value must be rejected with ValueError.
    try:
        dyn.pp.filter_cells_by_outliers(adata, layer="invalid_layer")
    except ValueError:
        pass
    else:
        assert False, "Expected ValueError"

def test_get_cell_phase():
    """Compare the phase scores from ``get_cell_phase`` with reference values.

    A 4-cell x 5-gene mock AnnData object is scored and the first five columns
    of the result are compared against hard-coded reference scores, one column
    per phase label ("G1-S", "S", "G2-M", "M", "M-G1").
    """
    # create a mock anndata object with mock data
    adata = anndata.AnnData(
        X=pd.DataFrame(
            [[1, 2, 7, 4, 1], [4, 3, 3, 5, 16], [5, 26, 7, 18, 9], [8, 39, 1, 1, 12]],
            columns=["arglu1", "dynll1", "cdca5", "cdca8", "ckap2"],
        )
    )

    # expected output
    expected_output = pd.DataFrame(
        {
            "G1-S": [0.52330, -0.28244, -0.38155, -0.13393],
            "S": [-0.77308, 0.98018, 0.39221, -0.38089],
            "G2-M": [-0.33656, -0.27547, 0.70090, 0.35216],
            "M": [0.07714, -0.36019, -0.67685, 0.77044],
            "M-G1": [0.50919, -0.06209, -0.03472, -0.60778],
        },
    )

    # Only the first five columns are compared; anything after them is ignored.
    phase_scores = get_cell_phase(adata).iloc[:, :5]

    # The comparison result must be asserted, otherwise the test can never
    # fail. The reference values are rounded to five decimals, so compare
    # with an absolute tolerance instead of np.allclose's tight defaults.
    assert np.allclose(phase_scores.to_numpy(), expected_output.to_numpy(), atol=1e-4)


def test_gene_selection_method():
adata = dyn.sample_data.zebrafish()
dyn.pl.basic_stats(adata)
Expand Down Expand Up @@ -357,7 +474,11 @@ def test_regress_out():
# test_highest_frac_genes_plot(adata.copy())
# test_highest_frac_genes_plot_prefix_list(adata.copy())
# test_recipe_monocle_feature_selection_layer_simple0()
# test_filter_genes_by_clusters_()
# test_filter_genes_by_outliers()
# test_filter_cells_by_outliers()
# test_get_cell_phase()
# test_gene_selection_method()
test_normalize()
# test_normalize()
# test_regress_out()
pass

0 comments on commit 7398a6a

Please sign in to comment.