Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified subset_samples_and_variants() #421

Merged
merged 14 commits into from
Dec 6, 2021
41 changes: 26 additions & 15 deletions gnomad/utils/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,26 +225,31 @@ def add_filters_expr(


def subset_samples_and_variants(
mt: hl.MatrixTable,
mtds: Union[hl.MatrixTable, hl.vds.VariantDataset],
sample_path: str,
header: bool = True,
table_key: str = "s",
sparse: bool = False,
gt_expr: str = "GT",
) -> hl.MatrixTable:
) -> Union[hl.MatrixTable, hl.vds.VariantDataset]:
"""
Subset the MatrixTable to the provided list of samples and their variants.
Subset the MatrixTable or VariantDataset to the provided list of samples and their variants.

:param mt: Input MatrixTable
:param mtds: Input MatrixTable or VariantDataset
:param sample_path: Path to a file with list of samples
:param header: Whether file with samples has a header. Default is True
:param table_key: Key to sample Table. Default is "s"
:param sparse: Whether the MatrixTable is sparse. Default is False
:param gt_expr: Name of field in MatrixTable containing genotype expression. Default is "GT"
:return: MatrixTable subsetted to specified samples and their variants
:return: MatrixTable or VariantDataset subsetted to specified samples and their variants
"""
sample_ht = hl.import_table(sample_path, no_header=not header, key=table_key)
sample_count = sample_ht.count()
is_vds = isinstance(mtds, hl.vds.VariantDataset)
if is_vds:
mt = mtds.variant_data
else:
wlu04 marked this conversation as resolved.
Show resolved Hide resolved
mt = mtds
missing_ht = sample_ht.anti_join(mt.cols())
missing_ht_count = missing_ht.count()
full_count = mt.count_cols()
Expand All @@ -253,24 +258,30 @@ def subset_samples_and_variants(
missing_samples = missing_ht.s.collect()
raise DataException(
f"Only {sample_count - missing_ht_count} out of {sample_count} "
"subsetting-table IDs matched IDs in the MT.\n"
f"subsetting-table IDs matched IDs in the {'VariantDataset' if is_vds else 'MatrixTable'}.\n"
f"IDs that aren't in the MT: {missing_samples}\n"
)

mt = mt.semi_join_cols(sample_ht)
if sparse:
mt = mt.filter_rows(
hl.agg.any(mt[gt_expr].is_non_ref() | hl.is_defined(mt.END))
)
if is_vds:
mtds = hl.vds.filter_samples(mtds, sample_ht, keep=True)
n_cols = mtds.variant_data.count_cols()
else:
mt = mt.filter_rows(hl.agg.any(mt[gt_expr].is_non_ref()))
mtds = mtds.semi_join_cols(sample_ht)
if sparse:
mtds = mtds.filter_rows(
hl.agg.any(mtds[gt_expr].is_non_ref() | hl.is_defined(mtds.END))
)
else:
mtds = mtds.filter_rows(hl.agg.any(mtds[gt_expr].is_non_ref()))
n_cols = mtds.count_cols()

logger.info(
"Finished subsetting samples. Kept %d out of %d samples in MT",
mt.count_cols(),
"Finished subsetting samples. Kept %d out of %d samples in %s",
n_cols,
full_count,
"VariantDataset" if is_vds else "MatrixTable",
)
return mt
return mtds


def filter_to_clinvar_pathogenic(
Expand Down