diff --git a/src/sourmash/tax/tax_utils.py b/src/sourmash/tax/tax_utils.py index b071539f0b..6828b61ecb 100644 --- a/src/sourmash/tax/tax_utils.py +++ b/src/sourmash/tax/tax_utils.py @@ -85,8 +85,7 @@ def collect_gather_csvs(cmdline_gather_input, *, from_file=None): return gather_csvs -def load_gather_results(gather_csv, *, delimiter=',', - essential_colnames=EssentialGatherColnames, +def load_gather_results(gather_csv, *, essential_colnames=EssentialGatherColnames, seen_queries=None, force=False): "Load a single gather csv" if not seen_queries: @@ -94,8 +93,7 @@ def load_gather_results(gather_csv, *, delimiter=',', header = [] gather_results = [] gather_queries = set() - with open(gather_csv, 'rt') as fp: - r = csv.DictReader(fp, delimiter=delimiter) + with sourmash_args.FileInputCSV(gather_csv) as r: header = r.fieldnames # check for empty file if not header: diff --git a/tests/test_tax.py b/tests/test_tax.py index b9ac15c171..6ff6ffdf4b 100644 --- a/tests/test_tax.py +++ b/tests/test_tax.py @@ -2010,6 +2010,40 @@ def test_annotate_0(runtmp): assert "d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri" in lin_gather_results[4] +def test_annotate_gzipped_gather(runtmp): + # test annotate with a gzipped gather CSV + c = runtmp + + g_csv = utils.get_test_data('tax/test1.gather.csv') + # rewrite gather_csv as gzipped csv + gz_gather = runtmp.output('test1.gather.csv.gz') + with open(g_csv, 'rb') as f_in, gzip.open(gz_gather, 'wb') as f_out: + f_out.writelines(f_in) + + tax = utils.get_test_data('tax/test.taxonomy.csv') + csvout = runtmp.output("test1.gather.with-lineages.csv") + out_dir = os.path.dirname(csvout) + + c.run_sourmash('tax', 'annotate', '--gather-csv', gz_gather, '--taxonomy-csv', tax, '-o', out_dir) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert os.path.exists(csvout) + + lin_gather_results = [x.rstrip() for x in open(csvout)] + 
print("\n".join(lin_gather_results)) + assert f"saving 'annotate' output to '{csvout}'" in runtmp.last_result.err + + assert "lineage" in lin_gather_results[0] + assert "d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria;o__Enterobacterales;f__Enterobacteriaceae;g__Escherichia;s__Escherichia coli" in lin_gather_results[1] + assert "d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri" in lin_gather_results[2] + assert "d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Phocaeicola;s__Phocaeicola vulgatus" in lin_gather_results[3] + assert "d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri" in lin_gather_results[4] + + def test_annotate_gather_argparse(runtmp): # test annotate with two gather CSVs, second one empty, and --force. # this tests argparse handling w/extend. diff --git a/tests/test_tax_utils.py b/tests/test_tax_utils.py index 5461b62a5a..e2f4838476 100644 --- a/tests/test_tax_utils.py +++ b/tests/test_tax_utils.py @@ -156,6 +156,17 @@ def test_load_gather_results(): assert len(gather_results) == 4 +def test_load_gather_results_gzipped(runtmp): + gather_csv = utils.get_test_data('tax/test1.gather.csv') + + # rewrite gather_csv as gzipped csv + gz_gather = runtmp.output('g.csv.gz') + with open(gather_csv, 'rb') as f_in, gzip.open(gz_gather, 'wb') as f_out: + f_out.writelines(f_in) + gather_results, header, seen_queries = load_gather_results(gz_gather) + assert len(gather_results) == 4 + + def test_load_gather_results_bad_header(runtmp): g_csv = utils.get_test_data('tax/test1.gather.csv')