Skip to content

Commit

Permalink
Merge pull request #666 from broadinstitute/mw/agg_by_strata_pr_fb
Browse files Browse the repository at this point in the history
Rearrange and enforce adj_group and group_membership being on the sam…
  • Loading branch information
jkgoodrich authored Jan 16, 2024
2 parents ec53431 + 5245a61 commit 30c4d17
Showing 1 changed file with 39 additions and 41 deletions.
80 changes: 39 additions & 41 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,87 +1957,85 @@ def agg_by_strata(
'group_membership' annotation that is a list of bools to aggregate the columns by.
:param mt: Input MatrixTable.
:param entry_agg_funcs: Optional dict of entry aggregation functions. When
specified, additional annotations are added to the output Table/MatrixTable.
The keys of the dict are the names of the annotations and the values are tuples
:param entry_agg_funcs: Dict of entry aggregation functions where the
keys of the dict are the names of the annotations and the values are tuples
of functions. The first function is used to transform the `mt` entries in some
way, and the second function is used to aggregate the output from the first
function.
:param select_fields: Optional list of row fields from `mt` to keep on the output
Table.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. If not provided, the 'group_membership'
to stratify the aggregations by. If not provided, the 'group_membership'
annotation is expected to be present on `mt`.
:return: Table or MatrixTable with allele frequencies by strata.
:return: Table with annotations of stratified aggregations.
"""
if entry_agg_funcs is None:
entry_agg_funcs = {}
if select_fields is None:
select_fields = []

n_samples = mt.count_cols()
global_expr = {}
if "adj_group" in mt.index_globals():
global_expr["adj_group"] = mt.index_globals().adj_group
logger.info("Using the 'adj_group' global annotation found on the input MT.")
raise TypeError(
"'agg_by_strata' expects a 'entry_agg_funcs' dictionary but it was not"
" supplied. Without the dictionary, no aggregations will occur."
)

if group_membership_ht is None and "group_membership" not in mt.col:
raise ValueError(
"The 'group_membership' annotation is not found in the input MatrixTable "
"and 'group_membership_ht' is not specified."
)
elif group_membership_ht is None:

if select_fields is None:
select_fields = []

if group_membership_ht is None:
logger.info(
"'group_membership_ht' is not specified, using sample stratification "
"indicated by the 'group_membership' annotation on mt."
)
n_groups = len(mt.group_membership.take(1)[0])
group_globals = mt.index_globals()
else:
logger.info(
"'group_membership_ht' is specified, using sample stratification indicated "
"by its 'group_membership' annotation."
)
group_globals = group_membership_ht.index_globals()
n_groups = len(group_membership_ht.group_membership.take(1)[0])
mt = mt.annotate_cols(
group_membership=group_membership_ht[mt.col_key].group_membership
)
if "adj_group" not in global_expr:
if "adj_group" in group_globals:
global_expr["adj_group"] = group_globals.adj_group
logger.info(
"Using the 'adj_group' global annotation on 'group_membership_ht'."
)
elif "freq_meta" in group_globals:
logger.info(
"The 'freq_meta' global annotation is found in "
"'group_membership_ht', using it to determine the adj filtered "
"stratification groups."
)
freq_meta = group_globals.freq_meta
global_expr["adj_group"] = freq_meta.map(
lambda x: x.get("group", "NA") == "adj"
)

if "adj_group" not in global_expr:
global_expr = {}
n_groups = len(mt.group_membership.take(1)[0])
if "adj_group" in group_globals:
global_expr["adj_group"] = group_globals.adj_group
logger.info("Using the 'adj_group' global annotation on 'group_membership_ht'.")
elif "freq_meta" in group_globals:
logger.info(
"The 'freq_meta' global annotation is found in "
"'group_membership_ht', using it to determine the adj filtered "
"stratification groups."
)
freq_meta = group_globals.freq_meta
global_expr["adj_group"] = freq_meta.map(
lambda x: x.get("group", "NA") == "adj"
)
else:
global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False)

# NOTE: Unsure if we still want this check here, since adj_group and n_groups will
# always be from the same table or built within this function. It's a cheap operation,
# so I'm leaning towards keeping it even though I'm not sure this is the right place
# for this check.
n_adj_group = hl.eval(hl.len(global_expr["adj_group"]))
if hl.eval(hl.len(global_expr["adj_group"])) != n_groups:
if n_adj_group != n_groups:
raise ValueError(
f"The number of elements in the 'adj_group' ({n_adj_group}) global "
"annotation does not match the number of elements in the "
f"'group_membership' annotation ({n_groups})!",
)

# Keep only the entries needed for the aggregation functions.
select_expr = {}
has_adj = False
if hl.eval(hl.any(global_expr["adj_group"])):
select_expr = {**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}}
has_adj = hl.eval(hl.any(global_expr["adj_group"]))
if has_adj:
select_expr["adj"] = mt.adj
has_adj = True

select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()})
mt = mt.select_entries(**select_expr)

# Convert MT to HT with a row annotation that is an array of all samples entries
Expand All @@ -2047,7 +2045,7 @@ def agg_by_strata(
# For each stratification group in group_membership, determine the indices of the
# samples that belong to that group.
global_expr["indices_by_group"] = hl.range(n_groups).map(
lambda g_i: hl.range(n_samples).filter(
lambda g_i: hl.range(mt.count_cols()).filter(
lambda s_i: ht.cols[s_i].group_membership[g_i]
)
)
Expand Down

0 comments on commit 30c4d17

Please sign in to comment.