Skip to content

Commit

Permalink
Merge pull request #662 from broadinstitute/jg/generalize_stat_compute_at_all_ref_sites

Browse files Browse the repository at this point in the history

Add `compute_stats_per_ref_site` to generalize computation of aggregate stats at all sites in a reference Table
  • Loading branch information
jkgoodrich authored Jan 19, 2024
2 parents 1f9ba18 + 988044c commit c310157
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 71 deletions.
8 changes: 6 additions & 2 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,9 +2069,13 @@ def _agg_by_group(
:param ann_expr: Expression to aggregate by group.
:return: Aggregated array expression.
"""
f = lambda i, adj: agg_func(ann_expr[i])
f_no_adj = lambda i, *args: agg_func(ann_expr[i])
if has_adj:
f = lambda i, adj: hl.if_else(adj, hl.agg.filter(ht.adj[i], f), f)
f = lambda i, adj: hl.if_else(
adj, hl.agg.filter(ht.adj[i], f_no_adj(i)), f_no_adj(i)
)
else:
f = f_no_adj

return hl.map(
lambda s_indices, adj: s_indices.aggregate(lambda i: f(i, adj)),
Expand Down
280 changes: 211 additions & 69 deletions gnomad/utils/sparse_mt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# noqa: D100

import logging
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import hail as hl

from gnomad.utils.annotations import (
agg_by_strata,
annotate_adj,
fs_from_sb,
generate_freq_group_membership_array,
get_adj_expr,
Expand Down Expand Up @@ -991,67 +992,91 @@ def densify_all_reference_sites(
return mt


def compute_coverage_stats(
def compute_stats_per_ref_site(
mtds: Union[hl.MatrixTable, hl.vds.VariantDataset],
reference_ht: hl.Table,
entry_agg_funcs: Dict[str, Tuple[Callable, Callable]],
row_key_fields: Union[Tuple[str], List[str]] = ("locus",),
interval_ht: Optional[hl.Table] = None,
coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100],
row_key_fields: List[str] = ["locus"],
entry_keep_fields: Union[Tuple[str], List[str], Set[str]] = None,
strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None,
group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
"""
Compute coverage statistics for every base of the `reference_ht` provided.
Compute stats per site in a reference Table.
The following coverage stats are calculated:
- mean
- median
- total DP
- fraction of samples with coverage above X, for each x in `coverage_over_x_bins`
The `reference_ht` is a Table that contains a row for each locus coverage that should be
computed on. It needs to be keyed by `locus`. The `reference_ht` can e.g. be
created using `get_reference_ht`.
:param mtds: Input sparse MT or VDS
:param reference_ht: Input reference HT
:param interval_ht: Optional Table containing intervals to filter to
:param coverage_over_x_bins: List of boundaries for computing samples over X
:param row_key_fields: List of row key fields to use for joining `mtds` with
`reference_ht`
:param strata_expr: Optional list of dicts containing expressions to stratify the
coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be
specified.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. Only one of `group_membership_ht` or
`strata_expr` can be specified.
:return: Table with per-base coverage stats.
:param mtds: Input sparse Matrix Table or VariantDataset.
:param reference_ht: Table of reference sites.
:param entry_agg_funcs: Dict of entry aggregation functions to perform on the
VariantDataset/MatrixTable. The keys of the dict are the names of the
annotations and the values are tuples of functions. The first function is used
to transform the `mt` entries in some way, and the second function is used to
aggregate the output from the first function.
:param row_key_fields: Fields to use as row key. Defaults to locus.
:param interval_ht: Optional table of intervals to filter to.
:param entry_keep_fields: Fields to keep in entries before performing the
densification in `densify_all_reference_sites`. Should include any fields
needed for the functions in `entry_agg_funcs`. By default, only GT or LGT is
kept.
:param strata_expr: Optional list of dicts of expressions to stratify by.
:param group_membership_ht: Optional Table of group membership annotations.
:return: Table of stats per site.
"""
is_vds = isinstance(mtds, hl.vds.VariantDataset)
if is_vds:
mt = mtds.variant_data
else:
mt = mtds

# Determine the genotype field.
gt_field = set(mt.entry) & {"GT", "LGT"}
if len(gt_field) == 0:
raise ValueError("No genotype field found in entry fields.")

gt_field = gt_field.pop()

if group_membership_ht is not None and strata_expr is not None:
raise ValueError(
"Only one of 'group_membership_ht' or 'strata_expr' can be specified."
)

# Determine if the adj annotation is needed. It is only needed if "adj_groups" is
# in the globals of the group_membership_ht and any entry is True, or "freq_meta"
# is in the globals of the group_membership_ht and any entry has "group" == "adj".
g = {} if group_membership_ht is None else group_membership_ht.globals
adj = hl.eval(
hl.any(g.get("adj_groups", hl.empty_array("bool")))
| hl.any(
g.get("freq_meta", hl.empty_array("dict<str, str>")).map(
lambda x: x.get("group", "NA") == "adj"
)
)
)

# Determine the entry fields on mt that should be densified.
# "GT" or "LGT" is required for the genotype.
# If the adj annotation is needed then "adj" must be present on mt, or AD/LAD, DP,
# and GQ must be present.
en = set(mt.entry)
gt_field = en & {"GT"} or en & {"LGT"}
ad_field = en & {"AD"} or en & {"LAD"}
adj_fields = en & {"adj"} or ({"DP", "GQ"} | ad_field) if adj else set([])

if not gt_field:
raise ValueError("No genotype field found in entry fields!")

if adj and not adj_fields.issubset(en):
raise ValueError(
"No 'adj' found in entry fields, and one of AD/LAD, DP, and GQ is missing "
"so adj can't be computed!"
)

entry_keep_fields = set(entry_keep_fields or set([])) | gt_field | adj_fields

# Initialize no_strata and default strata_expr if neither group_membership_ht nor
# strata_expr is provided.
no_strata = group_membership_ht is None and strata_expr is None
if no_strata:
strata_expr = {}

if group_membership_ht is None:
logger.warning(
"'group_membership_ht' is not specified, no stats are adj filtered."
)

# Annotate the MT cols with each of the expressions in strata_expr and redefine
# strata_expr based on the column HT with added annotations.
ht = mt.annotate_cols(
Expand Down Expand Up @@ -1084,12 +1109,6 @@ def compute_coverage_stats(
)
)

n_samples = group_membership_ht.count()
sample_counts = group_membership_ht.index_globals().freq_meta_sample_count

logger.info("Computing coverage stats on %d samples.", n_samples)

entry_keep_fields = set(mt.entry) & {gt_field, "DP"}
if is_vds:
rmt = mtds.reference_data
mtds = hl.vds.VariantDataset(
Expand All @@ -1105,7 +1124,87 @@ def compute_coverage_stats(
entry_keep_fields=entry_keep_fields,
)

# Compute coverage stats.
# Annotate with adj if needed.
if adj and "adj" not in mt.entry:
mt = annotate_adj(mt)

ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht)
ht = ht.checkpoint(hl.utils.new_temp_file("agg_stats", "ht"))

# Drop no longer needed fields
current_keys = list(ht.key)
ht = ht.key_by(*row_key_fields).select_globals()
ht = ht.drop(*[k for k in current_keys if k not in row_key_fields])

group_globals = group_membership_ht.index_globals()
global_expr = {}
if no_strata:
# If there was no stratification, move aggregated annotations to the top
# level.
ht = ht.select(**{ann: ht[ann][0] for ann in entry_agg_funcs})
global_expr["sample_count"] = group_globals.freq_meta_sample_count[0]
else:
# If there was stratification, add the metadata and sample count info for the
# stratification to the globals.
global_expr["strata_meta"] = group_globals.freq_meta
global_expr["strata_sample_count"] = group_globals.freq_meta_sample_count

ht = ht.annotate_globals(**global_expr)

return ht


def compute_coverage_stats(
mtds: Union[hl.MatrixTable, hl.vds.VariantDataset],
reference_ht: hl.Table,
interval_ht: Optional[hl.Table] = None,
coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100],
row_key_fields: List[str] = ["locus"],
strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None,
group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
"""
Compute coverage statistics for every base of the `reference_ht` provided.
The following coverage stats are calculated:
- mean
- median
- total DP
- fraction of samples with coverage above X, for each x in `coverage_over_x_bins`
The `reference_ht` is a Table that contains a row for each locus coverage that should be
computed on. It needs to be keyed by `locus`. The `reference_ht` can e.g. be
created using `get_reference_ht`.
:param mtds: Input sparse MT or VDS
:param reference_ht: Input reference HT
:param interval_ht: Optional Table containing intervals to filter to
:param coverage_over_x_bins: List of boundaries for computing samples over X
:param row_key_fields: List of row key fields to use for joining `mtds` with
`reference_ht`
:param strata_expr: Optional list of dicts containing expressions to stratify the
coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be
specified.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. Only one of `group_membership_ht` or
`strata_expr` can be specified.
:return: Table with per-base coverage stats.
"""
is_vds = isinstance(mtds, hl.vds.VariantDataset)
if is_vds:
mt = mtds.variant_data
else:
mt = mtds

# Determine the genotype field.
en = set(mt.entry)
gt_field = en & {"GT"} or en & {"LGT"}
if not gt_field:
raise ValueError("No genotype field found in entry fields!")

gt_field = gt_field.pop()

# Add function to compute coverage stats.
cov_bins = sorted(coverage_over_x_bins)
rev_cov_bins = list(reversed(cov_bins))
max_cov_bin = cov_bins[-1]
Expand All @@ -1123,8 +1222,17 @@ def compute_coverage_stats(
),
)
}
ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht)
ht = ht.checkpoint(hl.utils.new_temp_file("coverage_stats", "ht"))

ht = compute_stats_per_ref_site(
mtds,
reference_ht,
entry_agg_funcs,
row_key_fields=row_key_fields,
interval_ht=interval_ht,
entry_keep_fields=[gt_field, "DP"],
strata_expr=strata_expr,
group_membership_ht=group_membership_ht,
)

# This expression aggregates the DP counter in reverse order of the cov_bins and
# computes the cumulative sum over them. It needs to be in reverse order because we
Expand Down Expand Up @@ -1152,37 +1260,71 @@ def _cov_stats(

return cov_stat.annotate(**bin_expr).drop("coverage_counter")

ht = ht.annotate(
coverage_stats=hl.map(
lambda c, n: _cov_stats(c, n),
ht.coverage_stats,
sample_counts,
ht_globals = ht.index_globals()
if isinstance(ht.coverage_stats, hl.expr.ArrayExpression):
ht = ht.select_globals(
coverage_stats_meta=ht_globals.strata_meta.map(
lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group"))
),
coverage_stats_meta_sample_count=ht_globals.strata_sample_count,
)
)
current_keys = list(ht.key)
ht = ht.key_by(*row_key_fields).select_globals()
ht = ht.drop(*[k for k in current_keys if k not in row_key_fields])
cov_stats_expr = {
"coverage_stats": hl.map(
lambda c, n: _cov_stats(c, n),
ht.coverage_stats,
ht_globals.strata_sample_count,
)
}
else:
cov_stats_expr = _cov_stats(ht.coverage_stats, ht_globals.sample_count)

group_globals = group_membership_ht.index_globals()
global_expr = {}
if no_strata:
# If there was no stratification, move coverage_stats annotations to the top
# level.
ht = ht.select(**{k: ht.coverage_stats[0][k] for k in ht.coverage_stats[0]})
global_expr["sample_count"] = group_globals.freq_meta_sample_count[0]
ht = ht.transmute(**cov_stats_expr)

return ht


def get_allele_number_agg_func(gt_field: str = "GT") -> Tuple[Callable, Callable]:
    """
    Get a transformation and aggregation function for computing the allele number.

    Can be used as an entry aggregation function in `compute_stats_per_ref_site`.

    :param gt_field: Genotype field to use for computing the allele number.
    :return: Tuple of functions to transform and aggregate the allele number.
    """

    def _ploidy(t):
        # Each entry contributes its genotype ploidy (number of called alleles)
        # toward the allele number at a site.
        return t[gt_field].ploidy

    return _ploidy, hl.agg.sum


def compute_allele_number_per_ref_site(
    mtds: Union[hl.MatrixTable, hl.vds.VariantDataset],
    reference_ht: hl.Table,
    **kwargs,
) -> hl.Table:
    """
    Compute the allele number per reference site.

    :param mtds: Input sparse Matrix Table or VariantDataset.
    :param reference_ht: Table of reference sites.
    :param kwargs: Keyword arguments to pass to `compute_stats_per_ref_site`.
    :return: Table of allele number per reference site.
    """
    if isinstance(mtds, hl.vds.VariantDataset):
        mt = mtds.variant_data
    else:
        mt = mtds

    # Determine the genotype field: prefer "GT", fall back to local "LGT".
    en = set(mt.entry)
    gt_field = en & {"GT"} or en & {"LGT"}
    if not gt_field:
        raise ValueError(
            "No genotype field found in entry fields, needed for ploidy calculation!"
        )

    # Use ploidy to determine the number of alleles for each sample at each site.
    entry_agg_funcs = {"AN": get_allele_number_agg_func(gt_field.pop())}

    return compute_stats_per_ref_site(mtds, reference_ht, entry_agg_funcs, **kwargs)


def filter_ref_blocks(
Expand Down

0 comments on commit c310157

Please sign in to comment.