Skip to content

Commit

Permalink
Merge pull request #693 from broadinstitute/jg/validity_check_fail_pr…
Browse files Browse the repository at this point in the history
…int_fix

Fix `generic_field_check` in validity_checks.py print of failed checks
  • Loading branch information
jkgoodrich authored Apr 16, 2024
2 parents 1c95760 + 29c88d8 commit 77c9f60
Showing 1 changed file with 96 additions and 59 deletions.
155 changes: 96 additions & 59 deletions gnomad/assessment/validity_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from pprint import pprint
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import hail as hl
from hail.utils.misc import new_temp_file
Expand Down Expand Up @@ -48,10 +48,10 @@ def generic_field_check(
:param ht_count: Optional number of sites within hail Table (previously computed). If not supplied, a count of sites in the Table is performed.
:return: None
"""
if (n_fail is None and cond_expr is None) or (n_fail and cond_expr):
raise ValueError("One and only one of n_fail or cond_expr must be defined!")
if n_fail is None and cond_expr is None:
raise ValueError("At least one of n_fail or cond_expr must be defined!")

if cond_expr:
if n_fail is None and cond_expr is not None:
n_fail = ht.filter(cond_expr).count()

if show_percent_sites and (ht_count is None):
Expand All @@ -63,7 +63,9 @@ def generic_field_check(
logger.info(
"Percentage of sites that fail: %.2f %%", 100 * (n_fail / ht_count)
)
ht.select(**display_fields).show()
if cond_expr is not None:
ht = ht.select(_fail=cond_expr, **display_fields)
ht.filter(ht._fail).drop("_fail").show()
else:
logger.info("PASSED %s check", check_description)
if verbose:
Expand Down Expand Up @@ -197,9 +199,8 @@ def make_group_sum_expr_dict(
check_field_left = f"{subset}{metric}{delimiter}{group}"
check_field_right = f"sum{delimiter}{check_field_left}{delimiter}{sum_group}"
field_check_expr[f"{check_field_left} = {check_field_right}"] = {
"expr": hl.agg.count_where(
t.info[check_field_left] != annot_dict[check_field_right]
),
"expr": t.info[check_field_left] != annot_dict[check_field_right],
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{
check_field_left: t.info[check_field_left],
Expand Down Expand Up @@ -310,46 +311,34 @@ def _filter_agg_order(
logger.info(
"Checking distributions of filtered variants amongst variant filters..."
)
_filter_agg_order(t, {"is_filtered": t.is_filtered})

logger.info("Checking distributions of variant type amongst variant filters...")
_filter_agg_order(t, {"allele_type": t.info.allele_type})

logger.info(
"Checking distributions of variant type and region type amongst variant"
" filters..."
)
_filter_agg_order(
t,
{
"allele_type": t.info.allele_type,
"in_problematic_region": t.in_problematic_region,
},
n_rows,
n_cols,
)
_filter_agg_order(t, {"is_filtered": t.is_filtered}, n_rows, n_cols)

add_agg_expr = {}
if "allele_type" in t.info:
logger.info("Checking distributions of variant type amongst variant filters...")
add_agg_expr["allele_type"] = t.info.allele_type
_filter_agg_order(t, add_agg_expr, n_rows, n_cols)

if "in_problematic_region" in t.row:
logger.info(
"Checking distributions of variant type and region type amongst variant"
" filters..."
)
add_agg_expr["in_problematic_region"] = t.in_problematic_region
_filter_agg_order(t, add_agg_expr, n_rows, n_cols)

logger.info(
"Checking distributions of variant type, region type, and number of alt alleles"
" amongst variant filters..."
)
_filter_agg_order(
t,
{
"allele_type": t.info.allele_type,
"in_problematic_region": t.in_problematic_region,
"n_alt_alleles": t.info.n_alt_alleles,
},
n_rows,
n_cols,
)
if "n_alt_alleles" in t.info:
logger.info(
"Checking distributions of variant type, region type, and number of alt alleles"
" amongst variant filters..."
)
add_agg_expr["n_alt_alleles"] = t.info.n_alt_alleles
_filter_agg_order(t, add_agg_expr, n_rows, n_cols)


def generic_field_check_loop(
ht: hl.Table,
field_check_expr: Dict[
str, Dict[str, Union[hl.expr.Int64Expression, hl.expr.StructExpression]]
],
field_check_expr: Dict[str, Dict[str, Any]],
verbose: bool,
show_percent_sites: bool = False,
ht_count: int = None,
Expand All @@ -367,14 +356,15 @@ def generic_field_check_loop(
:return: None
"""
ht_field_check_counts = ht.aggregate(
hl.struct(**{k: v["expr"] for k, v in field_check_expr.items()})
hl.struct(**{k: v["agg_func"](v["expr"]) for k, v in field_check_expr.items()})
)
for check_description, n_fail in ht_field_check_counts.items():
generic_field_check(
ht,
check_description=check_description,
n_fail=n_fail,
display_fields=field_check_expr[check_description]["display_fields"],
cond_expr=field_check_expr[check_description]["expr"],
verbose=verbose,
show_percent_sites=show_percent_sites,
ht_count=ht_count,
Expand Down Expand Up @@ -433,9 +423,8 @@ def compare_subset_freqs(
)

field_check_expr[f"{check_field_left} != {check_field_right}"] = {
"expr": hl.agg.count_where(
t.info[check_field_left] == t.info[check_field_right]
),
"expr": t.info[check_field_left] == t.info[check_field_right],
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{
check_field_left: t.info[check_field_left],
Expand Down Expand Up @@ -591,18 +580,68 @@ def check_raw_and_adj_callstats(
t = t.rows() if isinstance(t, hl.MatrixTable) else t

field_check_expr = {}

for group in ["raw", "adj"]:
# Check AC and nhomalt missing if AN is missing and defined if AN is defined.
for subfield in ["AC", "nhomalt"]:
check_field = f"{subfield}{delimiter}{group}"
an_field = f"AN{delimiter}{group}"
field_check_expr[
f"{check_field} defined when AN defined and missing when AN missing"
] = {
"expr": hl.if_else(
hl.is_missing(t.info[an_field]),
hl.is_defined(t.info[check_field]),
hl.is_missing(t.info[check_field]),
),
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{an_field: t.info[an_field], check_field: t.info[check_field]}
),
}

# Check AF missing if AN is missing and defined if AN is defined and > 0.
check_field = f"AF{delimiter}{group}"
an_field = f"AN{delimiter}{group}"
field_check_expr[
f"{check_field} defined when AN defined (and > 0) and missing when AN missing"
] = {
"expr": hl.if_else(
hl.is_missing(t.info[an_field]),
hl.is_defined(t.info[check_field]),
(t.info[an_field] > 0) & hl.is_missing(t.info[check_field]),
),
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{an_field: t.info[an_field], check_field: t.info[check_field]}
),
}

# Check raw and adj AF missing if AN is 0.
check_field = f"AF{delimiter}{group}"
an_field = f"AN{delimiter}{group}"
field_check_expr[f"{check_field} missing when AN 0"] = {
"expr": (t.info[an_field] == 0) & hl.is_defined(t.info[check_field]),
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{an_field: t.info[an_field], check_field: t.info[check_field]}
),
}

for subfield in ["AC", "AF"]:
# Check raw AC, AF > 0
check_field = f"{subfield}{delimiter}raw"

field_check_expr[f"{check_field} > 0"] = {
"expr": hl.agg.count_where(t.info[check_field] <= 0),
"expr": t.info[check_field] <= 0,
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(**{check_field: t.info[check_field]}),
}

# Check adj AC, AF > 0
check_field = f"{subfield}{delimiter}adj"
field_check_expr[f"{check_field} >= 0"] = {
"expr": hl.agg.count_where(t.info[check_field] < 0),
"expr": t.info[check_field] < 0,
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{check_field: t.info[check_field], "filters": t.filters}
),
Expand All @@ -614,9 +653,8 @@ def check_raw_and_adj_callstats(
check_field_right = f"{subfield}{delimiter}adj"

field_check_expr[f"{check_field_left} >= {check_field_right}"] = {
"expr": hl.agg.count_where(
t.info[check_field_left] < t.info[check_field_right]
),
"expr": t.info[check_field_left] < t.info[check_field_right],
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{
check_field_left: t.info[check_field_left],
Expand All @@ -638,9 +676,8 @@ def check_raw_and_adj_callstats(
check_field_right = f"{field_check_label}adj"

field_check_expr[f"{check_field_left} >= {check_field_right}"] = {
"expr": hl.agg.count_where(
t.info[check_field_left] < t.info[check_field_right]
),
"expr": t.info[check_field_left] < t.info[check_field_right],
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{
check_field_left: t.info[check_field_left],
Expand Down Expand Up @@ -715,9 +752,9 @@ def check_sex_chr_metrics(
check_field_left = f"{metric}"
check_field_right = f"{standard_field}"
field_check_expr[f"{check_field_left} == {check_field_right}"] = {
"expr": hl.agg.count_where(
t_xnonpar.info[check_field_left] != t_xnonpar.info[check_field_right]
),
"expr": t_xnonpar.info[check_field_left]
!= t_xnonpar.info[check_field_right],
"agg_func": hl.agg.count_where,
"display_fields": hl.struct(
**{
check_field_left: t_xnonpar.info[check_field_left],
Expand Down

0 comments on commit 77c9f60

Please sign in to comment.