From feb4df0b29b354f58b3471f251028eaa7b97eb7d Mon Sep 17 00:00:00 2001
From: don
Date: Wed, 17 Jan 2018 13:34:29 -0800
Subject: [PATCH] Fix broken tests for #412. Add utility for rerunning resistance analysis.

---
Reviewer notes (this section is ignored when the patch is applied):
* Line breaks and indentation were rebuilt from a whitespace-collapsed copy
  of this patch, so the indentation of context lines is a best effort. Apply
  with --ignore-whitespace if a hunk is rejected.
* combine_files(): the indentation of its final yield could not be recovered.
  It is placed at loop level, so every paired sample gets rerun. Confirm.
* genreport_rerun.py was amended in review, so the blob id on its index line
  is stale. Amendments: guard the main() call, require both folder options,
  tolerate extra underscores in sample names, use os.replace(), open CSV
  output with newline='', and add docstrings.

 micall/tests/test_asi_algorithm.py |  20 +--
 micall/tests/test_hivdb.py         |  35 +----
 micall/utils/genreport_rerun.py    | 270 +++++++++++++++++++++++++++++
 3 files changed, 287 insertions(+), 38 deletions(-)
 create mode 100644 micall/utils/genreport_rerun.py

diff --git a/micall/tests/test_asi_algorithm.py b/micall/tests/test_asi_algorithm.py
index 04424d619..06c71271a 100644
--- a/micall/tests/test_asi_algorithm.py
+++ b/micall/tests/test_asi_algorithm.py
@@ -517,8 +517,8 @@ def test_interpret(self):
                            references=self.references)
         aa_seq = [['A']] * 40 + [['L']] + [['A']] * 59
         compared_attrs = ('code', 'score', 'level', 'level_name')
-        expected_drugs = [('MDP', 4.0, 3, 'Resistance Possible'),
-                          ('SIL', 4.0, 3, 'Resistance Possible')]
+        expected_drugs = [('MDP', 4.0, 4, 'Resistance Possible'),
+                          ('SIL', 4.0, 4, 'Resistance Possible')]
         expected_mutation_comments = []
 
         result = asi.interpret(aa_seq, 'HCV1A-H77-NS5a')
@@ -532,8 +532,8 @@ def test_score_and_flag(self):
                            references=self.references)
         aa_seq = [['A']] * 41 + [['L']] * 2 + [['A']] * 57
         compared_attrs = ('code', 'score', 'level', 'level_name')
-        expected_drugs = [('MDP', 8.0, 4, 'Resistance Likely'),
-                          ('SIL', 8.0, 4, 'Resistance Likely')]
+        expected_drugs = [('MDP', 8.0, 5, 'Resistance Likely'),
+                          ('SIL', 8.0, 5, 'Resistance Likely')]
         expected_mutation_comments = []
 
         result = asi.interpret(aa_seq, 'HCV1B-Con1-NS5a')
@@ -547,8 +547,8 @@ def test_flag_only(self):
                            references=self.references)
         aa_seq = [['A']] * 42 + [['L']] + [['A']] * 57
         compared_attrs = ('code', 'score', 'level', 'level_name')
-        expected_drugs = [('MDP', 0.0, 2, 'Mutations Detected; Effect Unknown'),
-                          ('SIL', 0.0, 2, 'Mutations Detected; Effect Unknown')]
+        expected_drugs = [('MDP', 0.0, 3, 'Mutations Detected; Effect Unknown'),
+                          ('SIL', 0.0, 3, 'Mutations Detected; Effect Unknown')]
         expected_mutation_comments = []
 
         result = asi.interpret(aa_seq, 'HCV1B-Con1-NS5a')
@@ -562,8 +562,8 @@ def test_not_indicated(self):
                            references=self.references)
         aa_seq = [['A']] * 100
         compared_attrs = ('code', 'score', 'level', 'level_name')
-        expected_drugs = [('MDP', 0.0, -1, 'Not Indicated'),
-                          ('SIL', 0.0, -1, 'Not Indicated')]
+        expected_drugs = [('MDP', 0.0, 2, 'Not Indicated'),
+                          ('SIL', 0.0, 2, 'Not Indicated')]
         expected_mutation_comments = []
 
         result = asi.interpret(aa_seq, 'HCV2-JFH-1-NS5a')
@@ -577,8 +577,8 @@ def test_not_available(self):
                            references=self.references)
         aa_seq = [['A']] * 100
         compared_attrs = ('code', 'score', 'level', 'level_name')
-        expected_drugs = [('MDP', 0.0, -2, 'Resistance Interpretation Not Available'),
-                          ('SIL', 0.0, -1, 'Not Indicated')]
+        expected_drugs = [('MDP', 0.0, -1, 'Resistance Interpretation Not Available'),
+                          ('SIL', 0.0, 2, 'Not Indicated')]
         expected_mutation_comments = []
 
         result = asi.interpret(aa_seq, 'HCV3-S52-NS5a')
diff --git a/micall/tests/test_hivdb.py b/micall/tests/test_hivdb.py
index 929639728..1087dcc50 100644
--- a/micall/tests/test_hivdb.py
+++ b/micall/tests/test_hivdb.py
@@ -514,13 +514,17 @@ def test_hcv(self):
         mutations_csv = StringIO()
         expected_resistance = """\
 region,drug_class,drug,drug_name,level,level_name,score,genotype
-NS5b,NS5b,DSV,DSV,-1,Not Indicated,0.0,6
-NS5b,NS5b,SOF-EPC,SOF-EPC,4,Resistance Likely,8.0,6
-NS5b,NS5b,SOF-HAR,SOF-HAR,-1,Not Indicated,0.0,6
+NS5b,NS5b,DSV,DSV,2,Not Indicated,0.0,6
+NS5b,NS5b,SOF-EPC,SOF-EPC,5,Resistance Likely,8.0,6
+NS5b,NS5b,SOF-HAR,SOF-HAR,2,Not Indicated,0.0,6
 """
         expected_mutations = """\
 drug_class,mutation,prevalence,genotype
+NS5b,L159T,1.0,6
 NS5b,S282T,1.0,6
+NS5b,C316T,1.0,6
+NS5b,L320T,1.0,6
+NS5b,V321T,1.0,6
 """
 
         write_resistance(aminos, resistance_csv, mutations_csv)
@@ -545,31 +549,6 @@ def test_hcv_low_coverage(self):
         self.assertEqual(expected_resistance, resistance_csv.getvalue())
         self.assertEqual(expected_mutations, mutations_csv.getvalue())
 
-    def test_multiple_mutations(self):
-        self.maxDiff = None
-        aminos = [AminoList('HCV1A-H77-NS5b',
-                            [{'R': 1.0}] * 591,
-                            '1A')]
-        resistance_csv = StringIO()
-        mutations_csv = StringIO()
-        expected_resistance = """\
-region,drug_class,drug,drug_name,level,level_name,score,genotype
-NS5b,NS5b,DSV,DSV,4,Resistance Likely,16.0,1A
-NS5b,NS5b,SOF-EPC,SOF-EPC,2,Mutations Detected; Effect Unknown,0.0,1A
-NS5b,NS5b,SOF-HAR,SOF-HAR,2,Mutations Detected; Effect Unknown,0.0,1A
-"""
-        expected_mutations = """\
-drug_class,mutation,prevalence,genotype
-NS5b,G307R,1.0,1A
-NS5b,S556R,1.0,1A
-NS5b,G558R,1.0,1A
-"""
-
-        write_resistance(aminos, resistance_csv, mutations_csv)
-
-        self.assertEqual(expected_resistance, resistance_csv.getvalue())
-        self.assertEqual(expected_mutations, mutations_csv.getvalue())
-
 
 class GenotypeTest(TestCase):
     def test_hiv(self):
diff --git a/micall/utils/genreport_rerun.py b/micall/utils/genreport_rerun.py
new file mode 100644
index 000000000..c853408a9
--- /dev/null
+++ b/micall/utils/genreport_rerun.py
@@ -0,0 +1,270 @@
+""" Rerun resistance interpretations on the HCV samples in a run folder.
+
+Splits a run's amino.csv and coverage_scores.csv into a working folder for
+each sample, pairs up the HCV and MidHCV samples from the same extraction,
+adds NS5b counts from one sample of each pair to the other, and then
+regenerates the resistance calls and PDF reports.
+"""
+import csv
+import os
+import re
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+from collections import namedtuple, defaultdict
+from csv import DictReader, DictWriter
+from itertools import groupby
+from operator import itemgetter, attrgetter
+
+from micall.core.aln2counts import AMINO_ALPHABET
+from micall.hivdb.genreport import gen_report
+from micall.hivdb.hivdb import hivdb
+
+SampleInfo = namedtuple('SampleInfo', 'enum suffix project snum name')
+SampleGroup = namedtuple('SampleGroup', 'enum names')
+
+
+def parse_args():
+    """ Parse the command line for the source and working folders. """
+    parser = ArgumentParser(
+        description='Rerun resistance interpretations on a run folder.',
+        formatter_class=ArgumentDefaultsHelpFormatter)
+
+    # Both folders get joined into file paths, so require them here instead
+    # of failing later with a TypeError inside os.path.join().
+    parser.add_argument('--source',
+                        '-s',
+                        required=True,
+                        help='source results folder')
+    parser.add_argument('--working',
+                        '-w',
+                        required=True,
+                        help='working folder')
+    return parser.parse_args()
+
+
+def find_groups(working_paths):
+    """ Pair up the samples that came from the same extraction.
+
+    :param working_paths: paths of the sample working folders, named
+        <sample name>_<sample number>
+    :return: a generator of SampleGroup tuples, one for each extraction
+        number with exactly two HCV or MidHCV samples. The names are sorted
+        by project.
+    """
+    groups = defaultdict(list)
+    for path in working_paths:
+        basename = os.path.basename(path)
+        # The sample number is the last field, so split from the right in
+        # case the sample name has its own underscores. A folder name with
+        # no underscore gets an empty sample name, and is skipped below.
+        sample_name, _, snum = basename.rpartition('_')
+        parts = sample_name.split('-')
+        for project in ('HCV', 'MidHCV'):
+            try:
+                project_index = parts.index(project)
+            except ValueError:
+                continue
+            if project_index == 0:
+                # Nothing comes before the project, so there is no extraction
+                # number. Index -1 would wrap around to the last part.
+                continue
+            extraction = parts[project_index-1]
+            if extraction.endswith('MIDI'):
+                extraction_num = extraction[:-4]
+                suffix = 'MIDI'
+            else:
+                extraction_num = extraction
+                suffix = ''
+            groups[extraction_num].append(SampleInfo(extraction_num,
+                                                     suffix,
+                                                     project,
+                                                     snum,
+                                                     basename))
+            break
+    for extraction_num, samples in sorted(groups.items()):
+        if len(samples) == 2:
+            names = tuple(sample.name
+                          for sample in sorted(samples,
+                                               key=attrgetter('project')))
+            yield SampleGroup(extraction_num, names)
+        else:
+            print("Couldn't group:", samples)
+
+
+def parse_sample_info(sample_name):
+    """ Split a sample name into its fields.
+
+    :param str sample_name: <extraction><suffix>-<project>_<sample number>
+    :return: a SampleInfo tuple
+    :raise ValueError: if the name doesn't have that layout
+    """
+    head, snum = sample_name.split('_')
+    head, project = head.split('-')
+    match = re.match(r'([A-Z]*\d+)(.*$)', head)
+    if match is None:
+        raise ValueError('Unexpected sample name: {!r}'.format(sample_name))
+    return SampleInfo(enum=match.group(1),
+                      suffix=match.group(2),
+                      project=project,
+                      snum=snum,
+                      name=sample_name)
+
+
+def combine_files(base_path, groups):
+    """ Add the second sample's NS5b counts to the first sample's amino.csv.
+
+    A group is only combined when the second sample's coverage_scores.csv
+    has an NS5b region with an on.score of 4. The first sample's original
+    file is kept as amino.csv.orig.csv.
+
+    :param base_path: the folder that holds the sample working folders
+    :param groups: SampleGroup tuples, as yielded by find_groups()
+    :return: a generator of working folders, one for the first sample in
+        each group
+    """
+    amino_columns = list(AMINO_ALPHABET) + ['del', 'coverage']
+    for group in groups:
+        src_filename = os.path.join(base_path,
+                                    group.names[1],
+                                    'coverage_scores.csv')
+        with open(src_filename) as src:
+            reader = DictReader(src)
+            has_good_coverage = False
+            for row in reader:
+                if row['region'].endswith('-NS5b') and row['on.score'] == '4':
+                    has_good_coverage = True
+                    break
+        if has_good_coverage:
+            dest_filename = os.path.join(base_path,
+                                         group.names[0],
+                                         'amino.csv')
+            src_filename = os.path.join(base_path,
+                                        group.names[1],
+                                        'amino.csv')
+            with open(src_filename) as src:
+                reader = DictReader(src)
+                source_rows = {(row['region'], row['refseq.aa.pos']): row
+                               for row in reader
+                               if row['region'].endswith('-NS5b')}
+            dest_copyname = dest_filename + '.orig.csv'
+            # replace() overwrites a copy left behind by an earlier run on
+            # any platform, but rename() fails on Windows when it exists.
+            os.replace(dest_filename, dest_copyname)
+            with open(dest_copyname) as src, \
+                    open(dest_filename, 'w', newline='') as dest:
+                reader = DictReader(src)
+                writer = DictWriter(dest, reader.fieldnames)
+                writer.writeheader()
+                for row in reader:
+                    source_row = source_rows.get((row['region'],
+                                                  row['refseq.aa.pos']),
+                                                 {})
+                    for column in amino_columns:
+                        dest_count = int(row[column])
+                        source_count = int(source_row.get(column, '0'))
+                        row[column] = dest_count + source_count
+                    writer.writerow(row)
+        yield os.path.join(base_path, group.names[0])
+
+
+def main():
+    """ Split the source files by sample, combine pairs, and rerun reports.
+
+    Writes resistance.csv, mutations.csv, and resistance_report.pdf in each
+    sample's working folder, then collates the two CSV files for all the
+    samples into the top of the working folder. Samples with no data rows
+    in both CSV files get listed in failed.csv.
+    """
+    args = parse_args()
+    working_paths = split_files(args)
+
+    sorted_working_paths = sorted(working_paths)
+    groups = find_groups(sorted_working_paths)
+    combined_working_paths = list(combine_files(args.working, groups))
+    failed_working_paths = set(combined_working_paths)
+    for working_path in combined_working_paths:
+        print(working_path)
+        with open(os.path.join(working_path, 'amino.csv')) as amino_csv, \
+                open(os.path.join(working_path, 'coverage_scores.csv')) as coverage_scores_csv, \
+                open(os.path.join(working_path, 'resistance.csv'), 'w') as resistance_csv, \
+                open(os.path.join(working_path, 'mutations.csv'), 'w') as mutations_csv:
+            hivdb(amino_csv,
+                  coverage_scores_csv,
+                  resistance_csv,
+                  mutations_csv)
+        sample_name = os.path.basename(working_path)
+        with open(os.path.join(working_path, 'resistance.csv')) as resistance_csv, \
+                open(os.path.join(working_path, 'mutations.csv')) as mutations_csv, \
+                open(os.path.join(working_path, 'resistance_report.pdf'), 'wb') as resistance_report_pdf:
+            gen_report(resistance_csv,
+                       mutations_csv,
+                       resistance_report_pdf,
+                       sample_name=sample_name)
+
+    for file_name in ('resistance.csv', 'mutations.csv'):
+        # newline='' lets the csv module control the line endings.
+        with open(os.path.join(args.working, file_name),
+                  'w',
+                  newline='') as dest:
+            dest_writer = csv.writer(dest)
+            for i, working_path in enumerate(combined_working_paths):
+                sample_name = os.path.basename(working_path)
+                with open(os.path.join(working_path, file_name), 'r') as source:
+                    source_reader = csv.reader(source)
+                    for j, row in enumerate(source_reader):
+                        if j != 0:
+                            row.insert(0, sample_name)
+                        elif i == 0:
+                            # Only copy the header from the first sample.
+                            row.insert(0, 'sample')
+                        else:
+                            continue
+                        if j == 1:
+                            # This sample produced at least one data row.
+                            failed_working_paths.discard(working_path)
+                        dest_writer.writerow(row)
+
+    with open(os.path.join(args.working, 'failed.csv'),
+              'w',
+              newline='') as dest:
+        dest_writer = csv.writer(dest)
+        dest_writer.writerow(['sample'])
+        for working_path in sorted(failed_working_paths):
+            dest_writer.writerow([os.path.basename(working_path)])
+
+
+def split_files(args):
+    """ Split the source CSV files into a working folder for each sample.
+
+    :param args: parsed command line with source and working attributes
+    :return: the set of working folder paths, one for each sample
+    """
+    working_paths = set()
+    for file_name in ('amino.csv', 'coverage_scores.csv'):
+        file_path = os.path.join(args.source, file_name)
+        with open(file_path) as f:
+            reader = DictReader(f)
+            for sample, rows in groupby(reader, itemgetter('sample')):
+                working_path = os.path.join(args.working, sample)
+                working_paths.add(working_path)
+                if __name__ == '__live_coding__':
+                    # Under this module name, write nothing and stop early.
+                    if len(working_paths) > 20:
+                        break
+                    continue
+                os.makedirs(working_path, exist_ok=True)
+                target_path = os.path.join(working_path, file_name)
+                with open(target_path, 'w', newline='') as target_csv:
+                    # Drop the sample column; assumes it is the first one.
+                    writer = DictWriter(target_csv,
+                                        reader.fieldnames[1:])
+                    writer.writeheader()
+                    for row in rows:
+                        del row['sample']
+                        writer.writerow(row)
+    return working_paths
+
+
+# Only run as a script or under the '__live_coding__' module name that
+# split_files() checks for, so that importing the module has no side effects.
+if __name__ in ('__main__', '__live_coding__'):
+    main()