diff --git a/gnomad_qc/v4/assessment/calculate_per_sample_stats.py b/gnomad_qc/v4/assessment/calculate_per_sample_stats.py index 58f032f27..88d6ff7f6 100644 --- a/gnomad_qc/v4/assessment/calculate_per_sample_stats.py +++ b/gnomad_qc/v4/assessment/calculate_per_sample_stats.py @@ -24,15 +24,22 @@ """ import argparse import logging -from typing import Optional +from copy import deepcopy +from typing import Dict, Optional import hail as hl -from gnomad.utils.filtering import filter_low_conf_regions +from gnomad.assessment.summary_stats import ( + get_summary_stats_csq_filter_expr, + get_summary_stats_filter_group_meta, + get_summary_stats_variant_filter_expr, +) +from gnomad.utils.annotations import annotate_with_ht from gnomad.utils.slack import slack_notifications from gnomad.utils.vep import ( CSQ_CODING, CSQ_NON_CODING, LOF_CSQ_SET, + LOFTEE_LABELS, filter_vep_transcript_csqs, get_most_severe_consequence_for_summary, ) @@ -54,206 +61,302 @@ logger = logging.getLogger("per_sample_stats") logger.setLevel(logging.INFO) +# V4 has no OS in LOFTEE_LABELS. +LOFTEE_LABELS = deepcopy(LOFTEE_LABELS) +LOFTEE_LABELS.remove("OS") + +SUM_STAT_FILTERS = { + "variant_qc": ["none", "pass"], # Quality control status of the variant. + "capture": [ # Capture methods used. + "ukb", + "broad", + "ukb_broad_intersect", + "ukb_broad_union", + ], + "max_af": [0.0001, 0.001, 0.01], # Maximum allele frequency thresholds. + "csq_set": ["lof", "coding", "non_coding"], # Consequence sets. + "lof_csq": deepcopy(LOF_CSQ_SET), # Loss-of-function consequence set. + "csq": [ # Additional consequence types. + "missense_variant", + "synonymous_variant", + "intron_variant", + "intergenic_variant", + ], +} +""" +Dictionary of default filter settings for summary stats. +""" -def create_per_sample_counts_ht( - mt: hl.MatrixTable, - annotation_ht: hl.Table, +COMMON_FILTERS = {"variant_qc": ["pass"], "capture": ["ukb_broad_intersect"]} +""" +Dictionary of common filter settings to use for most summary stats. +""" + +COMMON_FILTER_COMBOS = [["variant_qc"], ["variant_qc", "capture"]] +""" +List of common variant filter combinations to use for summary stats. +""" + +LOF_FILTERS_FOR_COMBO = { + "lof_csq_set": ["lof"], # Loss-of-function consequence set. + "loftee_label": deepcopy(LOFTEE_LABELS), # LOFTEE loss-of-function labels. + "loftee_HC": ["HC"], # High-confidence LOFTEE label. + "loftee_flags": ["no_flags", "with_flags"], # High-confidence LOFTEE flag options. +} +""" +Dictionary of an additional filter group to use for loss-of-function filter +combinations. +""" + +LOF_FILTER_COMBOS = [ + ["lof_csq", "loftee_label"], + ["lof_csq_set", "loftee_label"], + ["lof_csq_set", "loftee_HC", "loftee_flags"], + ["lof_csq", "loftee_HC", "loftee_flags"], +] +""" +List of loss-of-function consequence combinations to use for summary stats. +""" + +MAP_FILTER_FIELD_TO_META = { + "lof_csq": "csq", + "loftee_HC": "loftee_label", + "lof_csq_set": "csq_set", +} +""" +Dictionary to rename keys in `SUM_STAT_FILTERS`, `COMMON_FILTERS`, or +`LOF_FILTERS_FOR_COMBO` to final metadata keys. +""" + + +def get_capture_filter_exprs( + ht: hl.Table, + ukb_capture: bool = False, + broad_capture: bool = False, +) -> Dict[str, hl.expr.BooleanExpression]: + """ + Get filter expressions for UK Biobank and Broad capture regions. + + :param ht: Table containing variant annotations. The following annotations are + required: 'region_flags'. + :param ukb_capture: Expression for variants that are in UKB capture intervals. + :param broad_capture: Expression for variants that are in Broad capture intervals. + :return: Dictionary of filter expressions for UK Biobank and Broad capture regions. + """ + filter_expr = {} + log_list = [] + if ukb_capture: + log_list.append("variants in UK Biobank capture regions") + filter_expr["capture_ukb"] = ~ht.region_flags.outside_ukb_capture_region + + if broad_capture: + log_list.append("variants in Broad capture regions") + filter_expr["capture_broad"] = ~ht.region_flags.outside_broad_capture_region + + if ukb_capture and broad_capture: + log_list.append("variants in the intersect of UKB and Broad capture regions") + filter_expr["capture_ukb_broad_intersect"] = ( + filter_expr["capture_ukb"] & filter_expr["capture_broad"] + ) + + log_list.append("variants in the union of UKB and Broad capture regions") + filter_expr["capture_ukb_broad_union"] = ( + filter_expr["capture_ukb"] | filter_expr["capture_broad"] + ) + + logger.info("Adding filtering for:\n\t%s...", "\n\t".join(log_list)) + + return filter_expr + + +def get_summary_stats_filter_groups_ht( + ht: hl.Table, pass_filters: bool = False, ukb_capture: bool = False, broad_capture: bool = False, by_csqs: bool = False, - rare_variants: bool = False, vep_canonical: bool = True, vep_mane: bool = False, rare_variants_afs: Optional[list[float]] = None, ) -> hl.Table: """ - Create Table of Hail's sample_qc output broken down by requested variant groupings. + Create Table annotated with an array of booleans indicating whether a variant belongs to certain filter groups. - Useful for finding the number of variants per sample, either all variants, or - variants fall into specific capture regions, or variants that are rare - (adj AF <0.1%), or variants categorized by predicted consequences in all, canonical - or mane transcripts. + A 'filter_groups' annotation is added to the Table containing an ArrayExpression of + BooleanExpressions for each requested filter group. - :param mt: Input MatrixTable containing variant data. Must have multi-allelic sites - split. - :param annotation_ht: Table containing variant annotations. The following - annotations are required: 'freq', 'filters', and 'region_flags'. If `by_csqs` is - True, 'vep' is also required. + A 'filter_group_meta' global annotation is added to the Table containing an array + of dictionaries detailing the filters used in each filter group. + + :param ht: Table containing variant annotations. The following annotations are + required: 'freq', 'filters', and 'region_flags'. If `by_csqs` is True, 'vep' is + also required. :param pass_filters: Include count of variants that pass all variant QC filters. :param ukb_capture: Include count of variants that are in UKB capture intervals. :param broad_capture: Include count of variants that are in Broad capture intervals :param by_csqs: Include count of variants by variant consequence: loss-of-function, missense, and synonymous. - :param rare_variants: Include count of rare variants, defined as those which have - adj AF <0.1%. :param vep_canonical: If `by_csqs` is True, filter to only canonical transcripts. If trying count variants in all transcripts, set it to False. Default is True. :param vep_mane: If `by_csqs` is True, filter to only MANE transcripts. Default is False. :param rare_variants_afs: The allele frequency thresholds to use for rare variants. - :return: Table containing per-sample variant counts. + :return: Table containing an ArrayExpression of filter groups for summary stats. """ - logger.info("Filtering out low confidence regions...") - mt = filter_low_conf_regions(mt, filter_decoy=False) - - logger.info("Filtering input MT to variants in the supplied annotation HT...") - mt = mt.semi_join_rows(annotation_ht) - - # Add extra Allele Count and Allele Type annotations to variant MatrixTable, - # according to Hail standards, to help their computation. - variant_ac, variant_types = vmt_sample_qc_variant_annotations( - global_gt=mt.GT, alleles=mt.alleles - ) - - mt = mt.annotate_rows( - variant_ac=variant_ac, - variant_atypes=variant_types, - ) - - keep_annotations = ["freq", "filters", "region_flags"] + csq_filter_expr = {} if by_csqs: - annotation_ht = filter_vep_transcript_csqs( - annotation_ht, + # Filter to only canonical or MANE transcripts if requested and get the most + # severe consequence for each variant. + ht = filter_vep_transcript_csqs( + ht, synonymous=False, canonical=vep_canonical, mane_select=vep_mane, ) - annotation_ht = get_most_severe_consequence_for_summary(annotation_ht) - keep_annotations.extend(["most_severe_csq", "lof", "no_lof_flags"]) + ht = get_most_severe_consequence_for_summary(ht) - # Annotate the MT with the needed annotations. - annotation_ht = annotation_ht.select(*keep_annotations).checkpoint( - hl.utils.new_temp_file("annotation_ht", "ht") - ) - mt = mt.annotate_rows(**annotation_ht[mt.row_key]) + # Create filter expressions for the requested consequence types. + csq_filter_expr.update( + get_summary_stats_csq_filter_expr( + ht, + lof_csq_set=LOF_CSQ_SET, + lof_label_set=LOFTEE_LABELS, + lof_no_flags=True, + lof_any_flags=True, + additional_csq_sets={ + "coding": set(CSQ_CODING), + "non_coding": set(CSQ_NON_CODING), + }, + additional_csqs=set(SUM_STAT_FILTERS["csq"]), + ) + ) - filter_expr = {"all_variants": True} + # Create filter expressions for LCR, variant QC filters, and rare variant AFs if + # requested. + filter_exprs = { + "all_variants": hl.literal(True), + **get_capture_filter_exprs(ht, ukb_capture, broad_capture), + **get_summary_stats_variant_filter_expr( + ht, + filter_lcr=True, + filter_expr=ht.filters if pass_filters else None, + freq_expr=ht.freq[0].AF, + max_af=rare_variants_afs, + ), + **csq_filter_expr, + } + + # Create the metadata for all requested filter groups. + ss_filters = deepcopy(SUM_STAT_FILTERS) + ss_filters["max_af"] = rare_variants_afs + filter_group_meta = get_summary_stats_filter_group_meta( + ss_filters, + common_filter_combos=COMMON_FILTER_COMBOS, + common_filter_override=COMMON_FILTERS, + lof_filter_combos=LOF_FILTER_COMBOS, + lof_filter_override=LOF_FILTERS_FOR_COMBO, + filter_key_rename=MAP_FILTER_FIELD_TO_META, + ) - if pass_filters: - logger.info("Filtering to variants that pass all variant QC filters...") - filter_expr["pass_filters"] = hl.len(mt.filters) == 0 - if ukb_capture: - logger.info("Filtering to variants in UK Biobank capture regions...") - filter_expr["ukb_capture"] = ~mt.region_flags.outside_ukb_capture_region - if broad_capture: - logger.info("Filtering to variants in Broad capture regions...") - filter_expr["broad_capture"] = ~mt.region_flags.outside_broad_capture_region - if ukb_capture and broad_capture: - logger.info( - "Filtering to variants in the intersect of UKB and Broad capture regions..." - ) - filter_expr["ukb_broad_capture_intersect"] = ( - filter_expr["ukb_capture"] & filter_expr["broad_capture"] - ) - logger.info( - "Filtering to variants in the union of UKB and Broad capture regions..." - ) - filter_expr["ukb_broad_capture_union"] = ( - filter_expr["ukb_capture"] | filter_expr["broad_capture"] - ) - if all([ukb_capture, broad_capture, pass_filters]): - logger.info( - "Filtering to variants in the union of UKB and Broad capture regions and" - " pass filters..." - ) - filter_expr["ukb_broad_capture_union_pass_filters"] = ( - filter_expr["ukb_broad_capture_union"] & filter_expr["pass_filters"] - ) - if rare_variants: - for af in rare_variants_afs: - logger.info(f"Filtering to rare variants with adj AF <{af}...") - filter_expr[f"rare_{af}"] = mt.freq[0].AF < af - if by_csqs: - logger.info("Filtering to variants by consequence type...") - - def create_filter_by_csq( - csq_set: hl.set, lof_label: str = None, no_lof_flag: bool = None - ) -> hl.expr.BooleanExpression: - """ - Create filters based on consequence, labels, and flags. Always filters to variants which pass all variant qc filters. - - :param csq_set: Set of consequence types to filter by. - :param lof_label: Label to filter by loss-of-function annotations. - :param no_lof_flag: Flag to filter by loss-of-function annotations. - :return: Filter expression. - """ - base_filter = filter_expr["pass_filters"] & hl.any( - lambda csq: mt.most_severe_csq == csq, csq_set - ) - if lof_label: - base_filter &= mt.lof == lof_label - if no_lof_flag is not None: - base_filter &= mt.no_lof_flags == no_lof_flag - return base_filter - - filter_expr["coding"] = create_filter_by_csq(set(CSQ_CODING)) - filter_expr["non_coding"] = create_filter_by_csq(set(CSQ_NON_CODING)) - filter_expr["lof"] = create_filter_by_csq(LOF_CSQ_SET) - - for lof_label in ["HC", "LC", "OS"]: - filter_expr[f"lof_{lof_label}"] = create_filter_by_csq( - LOF_CSQ_SET, lof_label + # Create filter expressions for each filter group by combining the filter + # expressions for each filter in the filter group metadata. + filter_groups_expr = [] + final_meta = [] + for filter_group in filter_group_meta: + # Initialize filter expression for the filter group with True to allow for + # a filtering group that has no filters, e.g. all variants. + filter_expr = hl.literal(True) + filter_group_requested = True + for k, v in filter_group.items(): + # Rename "loftee_flags" to "loftee" to match the filter expression keys. + k = k.replace("loftee_flags", "loftee") + # Determine the correct key for filter_expr, it can be a combination of + # the key and value, or just the key followed by using the value to get the + # filter expression from a struct. + f_expr = filter_exprs.get(f"{k}_{v}") + f_struct = filter_exprs.get(k) + # If the filter group is in the combinations, but not filter_exprs, then + # the filter group was not in the requested list. + if f_expr is None and f_struct is None: + filter_group_requested = False + break + filter_expr &= f_struct[v] if f_expr is None else f_expr + + if filter_group_requested: + filter_groups_expr.append(filter_expr) + final_meta.append(filter_group) + else: + logger.warning( + "Filter group %s was not requested and will not be included in the " + "summary stats.", + filter_group, ) - for no_lof_flag in [True, False]: - flag_desc = "with" if not no_lof_flag else "no" - filter_expr[f"lof_HC_{flag_desc}_flags"] = create_filter_by_csq( - LOF_CSQ_SET, lof_label="HC", no_lof_flag=no_lof_flag - ) + # Remove 'no_lcr' filter expression from filter groups and annotate the Table with + # the no_lcr filter and an array of the filter groups. + ht = ht.select(_no_lcr=filter_exprs["no_lcr"], filter_groups=filter_groups_expr) - # LOF variants breakdowns - for lof_variant in LOF_CSQ_SET: - for no_lof_flag in [True, False]: - flag_desc = "with" if not no_lof_flag else "no" - filter_expr[f"{lof_variant}_HC_{flag_desc}_flags"] = ( - create_filter_by_csq( - {lof_variant}, lof_label="HC", no_lof_flag=no_lof_flag - ) - ) - for lof_label in ["LC", "OS"]: - filter_expr[f"{lof_variant}_{lof_label}"] = create_filter_by_csq( - {lof_variant}, lof_label=lof_label - ) + ht = ht.select_globals(filter_group_meta=final_meta) + logger.info("Filter groups for summary stats: %s", filter_group_meta) + + # Filter to only variants that are not in low confidence regions. + ht = ht.filter(ht._no_lcr).drop("_no_lcr") + ht = ht.checkpoint(hl.utils.new_temp_file("stats_annotation", "ht")) + + return ht + + +def create_per_sample_counts_ht( + mt: hl.MatrixTable, filter_group_ht: hl.Table +) -> hl.Table: + """ + Create Table of Hail's sample_qc output broken down by requested variant groupings. + + Useful for finding the number of variants per sample, either all variants, or + variants fall into specific capture regions, or variants that are rare + (adj AF <0.1%), or variants categorized by predicted consequences in all, canonical + or mane transcripts. - for csq in [ - "missense_variant", - "synonymous_variant", - "intron_variant", - "intergenic_variant", - ]: - filter_expr[csq] = create_filter_by_csq({csq}) + :param mt: Input MatrixTable containing variant data. Must have multi-allelic sites + split. + :param filter_group_ht: Table containing filter groups for summary stats. + :return: Table containing per-sample variant counts. + """ + # Add extra Allele Count and Allele Type annotations to variant MatrixTable, + # according to Hail standards, to help their computation. + variant_ac, variant_types = vmt_sample_qc_variant_annotations( + global_gt=mt.GT, alleles=mt.alleles + ) + mt = mt.annotate_rows(variant_ac=variant_ac, variant_atypes=variant_types) + + # Annotate the MT with the needed annotations. + mt = annotate_with_ht(mt, filter_group_ht, filter_missing=True) # Run Hail's 'vmt_sample_qc' for all requested filter groups. + qc_expr = vmt_sample_qc( + global_gt=mt.GT, + gq=mt.GQ, + variant_ac=mt.variant_ac, + variant_atypes=mt.variant_atypes, + dp=mt.DP, + ) ht = mt.select_cols( - _sample_qc=hl.struct( - **{ - ann: hl.agg.filter( - expr, - vmt_sample_qc( - global_gt=mt.GT, - gq=mt.GQ, - variant_ac=mt.variant_ac, - variant_atypes=mt.variant_atypes, - dp=mt.DP, - ), - ) - for ann, expr in filter_expr.items() - } + summary_stats=hl.agg.array_agg( + lambda f: hl.agg.filter(f, qc_expr), mt.filter_groups ) ).cols() - - ht = ht.select(**ht._sample_qc) + ht = ht.annotate_globals( + summary_stats_meta=filter_group_ht.index_globals().filter_group_meta + ) + ht = ht.checkpoint(hl.utils.new_temp_file("per_sample_counts", "ht")) # Add 'n_indel' to the output Table. - for field in ht.row_value: - ht = ht.annotate( - **{ - field: ht[field].annotate( - n_indel=ht[field].n_insertion + ht[field].n_deletion - ) - } + ht = ht.annotate( + summary_stats=ht.summary_stats.map( + lambda x: x.annotate(n_indel=x.n_insertion + x.n_deletion) ) + ) + return ht @@ -274,51 +377,39 @@ def compute_agg_sample_stats( working on "exomes" data. :return: Struct of aggregate statistics for per-sample QC metrics. """ - if by_ancestry and meta_ht is None: - raise ValueError( - "If `by_ancestry` is True, a Table containing sample metadata is required." - ) - - subset = ( - ["gnomad"] - if not by_subset - else [ - "gnomad", - hl.if_else(meta_ht[ht.s].project_meta.ukb_sample, "ukb", "non-ukb"), - ] - ) - gen_anc = ( - ["global"] - if not by_ancestry - else ["global", meta_ht[ht.s].population_inference.pop] - ) - - all_strats = [ - strat - for strat in ht.row_value - if isinstance(ht[strat], hl.expr.StructExpression) - ] + if meta_ht is None and by_ancestry: + raise ValueError("If `by_ancestry` is True, `meta_ht` is required.") + if meta_ht is None and by_subset: + raise ValueError("If `by_subset` is True, `meta_ht` is required.") + + subset = ["gnomad"] + gen_anc = ["global"] + if meta_ht is not None: + meta_s = meta_ht[ht.s] + subset_expr = hl.if_else(meta_s.project_meta.ukb_sample, "ukb", "non-ukb") + subset += [subset_expr] if by_subset else [] + gen_anc += [meta_s.population_inference.pop] if by_ancestry else [] ht = ht.transmute( subset=subset, gen_anc=gen_anc, - stats_array=[(strat, ht[strat]) for strat in all_strats], + summary_stats=hl.zip(ht.summary_stats_meta, ht.summary_stats), ) - ht = ht.explode("stats_array").explode("gen_anc").explode("subset") + ht = ht.explode("summary_stats").explode("gen_anc").explode("subset") - ht = ht.group_by("subset", "gen_anc", variant_filter=ht.stats_array[0]).aggregate( + ht = ht.group_by("subset", "gen_anc", variant_filter=ht.summary_stats[0]).aggregate( **{ metric: hl.struct( - mean=hl.agg.mean(ht.stats_array[1][metric]), - min=hl.agg.min(ht.stats_array[1][metric]), - max=hl.agg.max(ht.stats_array[1][metric]), + mean=hl.agg.mean(ht.summary_stats[1][metric]), + min=hl.agg.min(ht.summary_stats[1][metric]), + max=hl.agg.max(ht.summary_stats[1][metric]), quantiles=hl.agg.approx_quantiles( - ht.stats_array[1][metric], [0.0, 0.25, 0.5, 0.75, 1.0] + ht.summary_stats[1][metric], [0.0, 0.25, 0.5, 0.75, 1.0] ), ) - for metric in ht.stats_array[1] - if isinstance(ht.stats_array[1][metric], hl.expr.NumericExpression) + for metric in ht.summary_stats[1] + if isinstance(ht.summary_stats[1][metric], hl.expr.NumericExpression) } ) @@ -338,6 +429,9 @@ def main(args): data_type = args.data_type test = args.test overwrite = args.overwrite + ukb_capture_intervals = not args.skip_filter_ukb_capture_intervals + broad_capture_intervals = not args.skip_filter_broad_capture_intervals + rare_variants_afs = args.rare_variants_afs if not args.skip_rare_variants else None per_sample_res = get_per_sample_counts( test=test, data_type=data_type, suffix=args.custom_suffix ) @@ -365,6 +459,8 @@ def main(args): mt = get_gnomad_v4_genomes_vds( test=test, release_only=True, split=True, chrom=chrom ).variant_data + ukb_capture_intervals = False + broad_capture_intervals = False release_ht = release_sites(data_type=data_type).ht() if test: @@ -372,26 +468,19 @@ def main(args): release_ht, [hl.parse_locus_interval("chr22")] ) - create_per_sample_counts_ht( - mt, + filter_groups_ht = get_summary_stats_filter_groups_ht( release_ht, pass_filters=not args.skip_pass_filters, - ukb_capture=( - not args.skip_filter_ukb_capture_intervals - if data_type == "exomes" - else False - ), - broad_capture=( - not args.skip_filter_broad_capture_intervals - if data_type == "exomes" - else False - ), + ukb_capture=ukb_capture_intervals, + broad_capture=broad_capture_intervals, by_csqs=not args.skip_by_csqs, - rare_variants=not args.skip_rare_variants, vep_canonical=args.vep_canonical, vep_mane=args.vep_mane, - rare_variants_afs=args.rare_variants_afs, - ).write(per_sample_res.path, overwrite=overwrite) + rare_variants_afs=rare_variants_afs, + ) + create_per_sample_counts_ht(mt, filter_groups_ht).write( + per_sample_res.path, overwrite=overwrite + ) if args.aggregate_sample_stats: logger.info("Computing aggregate sample statistics...") @@ -477,7 +566,7 @@ def main(args): parser.add_argument( "--rare-variants-afs", type=float, - default=[0.0001, 0.001, 0.01], + default=SUM_STAT_FILTERS["max_af"], help="The allele frequency threshold to use for rare variants.", ) parser.add_argument(