Skip to content

Commit

Permalink
Change rules of combining MIDI overlap for #412.
Browse files Browse the repository at this point in the history
Use sample sheet to find matching sample names.
  • Loading branch information
donkirkby committed Jan 18, 2018
1 parent f79df1b commit 933e5d5
Showing 1 changed file with 68 additions and 94 deletions.
162 changes: 68 additions & 94 deletions micall/utils/genreport_rerun.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import csv
import os
import re
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from collections import namedtuple, defaultdict
from collections import namedtuple
from csv import DictReader, DictWriter
from itertools import groupby
from operator import itemgetter, attrgetter
from operator import itemgetter

from micall.core.aln2counts import AMINO_ALPHABET
from micall.hivdb.genreport import gen_report
from micall.hivdb.hivdb import hivdb
from micall.utils.sample_sheet_parser import sample_sheet_parser

SampleInfo = namedtuple('SampleInfo', 'enum suffix project snum name')
SampleGroup = namedtuple('SampleGroup', 'enum names')


Expand All @@ -29,49 +28,25 @@ def parse_args():
return parser.parse_args()


def find_groups(working_paths, source_path):
    """Pair each wide HCV sample with its MIDI counterpart via the sample sheet.

    The sample sheet is read from ``../../SampleSheet.csv`` relative to
    *source_path*. Rows of its ``DataSplit`` section whose project is
    ``'HCV'`` identify the wide samples; rows whose project is ``'MidHCV'``
    identify the MIDI samples. A MIDI sample matches a wide sample when its
    sample name is the wide sample name followed by ``'MIDI'``.

    :param working_paths: iterable of working folder paths. Only the base
        name of each path is used, and it is compared with the ``filename``
        field of the HCV rows.
    :param source_path: folder that the sample sheet path is resolved from.
    :return: generator of ``SampleGroup(sample_name, (wide_file, midi_file))``.
        ``midi_file`` is ``None`` when the sample sheet has no matching MIDI
        entry. Paths that are not listed as HCV samples are skipped.
    """
    sample_sheet_path = os.path.join(source_path, '../../SampleSheet.csv')
    with open(sample_sheet_path) as sample_sheet_file:
        run_info = sample_sheet_parser(sample_sheet_file)

    # MIDI lookup is keyed by sample name; wide lookup is keyed by file name,
    # because the working paths only give us file names to start from.
    midi_files = {row['sample']: row['filename']
                  for row in run_info['DataSplit']
                  if row['project'] == 'MidHCV'}
    wide_names = {row['filename']: row['sample']
                  for row in run_info['DataSplit']
                  if row['project'] == 'HCV'}

    for path in working_paths:
        wide_file = os.path.basename(path)
        sample_name = wide_names.get(wide_file)
        if sample_name is None:
            # Not an HCV sample.
            continue
        midi_file = midi_files.get(sample_name + 'MIDI')
        yield SampleGroup(sample_name, (wide_file, midi_file))


def rewrite_file(filename):
Expand All @@ -88,67 +63,66 @@ def rewrite_file(filename):


def combine_files(base_path, groups):
    """Merge MIDI results into each group's wide sample and yield its folder.

    :param base_path: folder that holds one working folder per sample.
    :param groups: iterable of objects with a ``names`` pair:
        ``(wide_name, midi_name)``. ``midi_name`` may be ``None`` when the
        wide sample has no MIDI counterpart.
    :return: generator of the wide sample's working path for every group,
        whether or not a MIDI sample was merged into it.
    """
    for group in groups:
        wide_name, midi_name = group.names[0], group.names[1]
        if midi_name is not None:
            combine_midi(base_path, wide_name, midi_name)
        yield os.path.join(base_path, wide_name)


def combine_midi(base_path, wide_name, midi_name):
    """Fold a MIDI sample's NS5b results into the matching wide sample.

    Works on the ``coverage_scores.csv`` and ``amino.csv`` files found in
    ``base_path/wide_name`` and ``base_path/midi_name``:

    1. Find a seed whose NS5b region scored ``'4'`` in the MIDI sample's
       MidHCV project.
    2. In the wide sample's coverage scores, keep an NS5b score of ``'4'``
       only for a seed that the MIDI sample also covered; otherwise
       downgrade it to ``'1'``.
    3. If any wide NS5b score survived, replace the wide sample's NS5b amino
       counts with the MIDI counts at positions above 335, and at positions
       226 to 335 when the MIDI coverage is higher than the wide coverage.

    :param base_path: folder that holds one working folder per sample.
    :param wide_name: working folder name of the wide sample; its files are
        modified.
    :param midi_name: working folder name of the MIDI sample; its files are
        only read.
    """
    amino_columns = list(AMINO_ALPHABET) + ['del', 'coverage']

    midi_scores_filename = os.path.join(base_path,
                                        midi_name,
                                        'coverage_scores.csv')
    midi_covered_seeds = set()
    with open(midi_scores_filename) as src:
        for row in DictReader(src):
            if (row['region'].endswith('-NS5b') and
                    row['project'] == 'MidHCV' and
                    row['on.score'] == '4'):
                midi_covered_seeds.add(row['seed'])
                # Only the first fully covered seed is recorded.
                break

    wide_scores_filename = os.path.join(base_path,
                                        wide_name,
                                        'coverage_scores.csv')
    has_good_coverage = False
    # rewrite_file() is defined elsewhere in this file; it is used here as an
    # iterator of row dicts that may be edited in place.
    # NOTE(review): presumably it writes the edited rows back - confirm.
    for row in rewrite_file(wide_scores_filename):
        if row['region'].endswith('-NS5b') and row['on.score'] == '4':
            if row['seed'] in midi_covered_seeds:
                has_good_coverage = True
            else:
                row['on.score'] = '1'

    if not has_good_coverage:
        return

    midi_amino_filename = os.path.join(base_path, midi_name, 'amino.csv')
    with open(midi_amino_filename) as src:
        source_rows = {(row['region'], row['refseq.aa.pos']): row
                       for row in DictReader(src)
                       if row['region'].endswith('-NS5b')}

    wide_amino_filename = os.path.join(base_path, wide_name, 'amino.csv')
    for row in rewrite_file(wide_amino_filename):
        source_row = source_rows.get((row['region'], row['refseq.aa.pos']))
        if source_row is None:
            continue
        pos = int(row['refseq.aa.pos'])
        wide_coverage = int(row['coverage'])
        # The MIDI coverage has to come from the MIDI row. Reading it from
        # the wide row made the comparison below always false.
        midi_coverage = int(source_row['coverage'])
        if pos > 335 or (pos >= 226 and midi_coverage > wide_coverage):
            for column in amino_columns:
                row[column] = source_row[column]


def main():
args = parse_args()
working_paths = split_files(args)

sorted_working_paths = sorted(working_paths)
groups = find_groups(sorted_working_paths)
groups = find_groups(sorted_working_paths, args.source)
combined_working_paths = list(combine_files(args.working, groups))
failed_working_paths = set(combined_working_paths)
for working_path in combined_working_paths:
Expand Down

0 comments on commit 933e5d5

Please sign in to comment.