From 5245a614bdefe3e0e5447cd7f1c3a2973fe69c83 Mon Sep 17 00:00:00 2001 From: Mike Wilson Date: Thu, 11 Jan 2024 15:27:01 -0500 Subject: [PATCH] Rearrange and enforce adj_group and group_membership being on the same HT/MT --- gnomad/utils/annotations.py | 80 ++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index ec40287ea..a312f6684 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1957,73 +1957,73 @@ def agg_by_strata( 'group_membership' annotation that is a list of bools to aggregate the columns by. :param mt: Input MatrixTable. - :param entry_agg_funcs: Optional dict of entry aggregation functions. When - specified, additional annotations are added to the output Table/MatrixTable. - The keys of the dict are the names of the annotations and the values are tuples + :param entry_agg_funcs: Dict of entry aggregation functions where the + keys of the dict are the names of the annotations and the values are tuples of functions. The first function is used to transform the `mt` entries in some way, and the second function is used to aggregate the output from the first function. :param select_fields: Optional list of row fields from `mt` to keep on the output Table. :param group_membership_ht: Optional Table containing group membership annotations - to stratify the coverage stats by. If not provided, the 'group_membership' + to stratify the aggregations by. If not provided, the 'group_membership' annotation is expected to be present on `mt`. - :return: Table or MatrixTable with allele frequencies by strata. + :return: Table with annotations of stratified aggregations. """ if entry_agg_funcs is None: - entry_agg_funcs = {} - if select_fields is None: - select_fields = [] - - n_samples = mt.count_cols() - global_expr = {} - if "adj_group" in mt.index_globals(): - global_expr["adj_group"] = mt.index_globals().adj_group - logger.info("Using the 'adj_group' global annotation found on the input MT.") + raise TypeError( + "'agg_by_strata' expects a 'entry_agg_funcs' dictionary but it was not" + " supplied. Without the dictionary, no aggregations will occur." + ) if group_membership_ht is None and "group_membership" not in mt.col: raise ValueError( "The 'group_membership' annotation is not found in the input MatrixTable " "and 'group_membership_ht' is not specified." ) - elif group_membership_ht is None: + + if select_fields is None: + select_fields = [] + + if group_membership_ht is None: logger.info( "'group_membership_ht' is not specified, using sample stratification " "indicated by the 'group_membership' annotation on mt." ) - n_groups = len(mt.group_membership.take(1)[0]) + group_globals = mt.index_globals() else: logger.info( "'group_membership_ht' is specified, using sample stratification indicated " "by its 'group_membership' annotation." ) group_globals = group_membership_ht.index_globals() - n_groups = len(group_membership_ht.group_membership.take(1)[0]) mt = mt.annotate_cols( group_membership=group_membership_ht[mt.col_key].group_membership ) - if "adj_group" not in global_expr: - if "adj_group" in group_globals: - global_expr["adj_group"] = group_globals.adj_group - logger.info( - "Using the 'adj_group' global annotation on 'group_membership_ht'." - ) - elif "freq_meta" in group_globals: - logger.info( - "The 'freq_meta' global annotation is found in " - "'group_membership_ht', using it to determine the adj filtered " - "stratification groups." - ) - freq_meta = group_globals.freq_meta - global_expr["adj_group"] = freq_meta.map( - lambda x: x.get("group", "NA") == "adj" - ) - if "adj_group" not in global_expr: + global_expr = {} + n_groups = len(mt.group_membership.take(1)[0]) + if "adj_group" in group_globals: + global_expr["adj_group"] = group_globals.adj_group + logger.info("Using the 'adj_group' global annotation on 'group_membership_ht'.") + elif "freq_meta" in group_globals: + logger.info( + "The 'freq_meta' global annotation is found in " + "'group_membership_ht', using it to determine the adj filtered " + "stratification groups." + ) + freq_meta = group_globals.freq_meta + global_expr["adj_group"] = freq_meta.map( + lambda x: x.get("group", "NA") == "adj" + ) + else: global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) + # NOTE: Unsure if we still want this check here since the adj_group and n_groups + # always be from the same table or built within this function? Its a cheap operation + # so I'm leaning towards keeping it even though I'm not sure this is the right place + # for this check. n_adj_group = hl.eval(hl.len(global_expr["adj_group"])) - if hl.eval(hl.len(global_expr["adj_group"])) != n_groups: + if n_adj_group != n_groups: raise ValueError( f"The number of elements in the 'adj_group' ({n_adj_group}) global " "annotation does not match the number of elements in the " @@ -2031,13 +2031,11 @@ def agg_by_strata( ) # Keep only the entries needed for the aggregation functions. - select_expr = {} - has_adj = False - if hl.eval(hl.any(global_expr["adj_group"])): + select_expr = {**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}} + has_adj = hl.eval(hl.any(global_expr["adj_group"])) + if has_adj: select_expr["adj"] = mt.adj - has_adj = True - select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}) mt = mt.select_entries(**select_expr) # Convert MT to HT with a row annotation that is an array of all samples entries @@ -2047,7 +2045,7 @@ def agg_by_strata( # For each stratification group in group_membership, determine the indices of the # samples that belong to that group. global_expr["indices_by_group"] = hl.range(n_groups).map( - lambda g_i: hl.range(n_samples).filter( + lambda g_i: hl.range(mt.count_cols()).filter( lambda s_i: ht.cols[s_i].group_membership[g_i] ) )