Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Fix sourmash prefetch to work when db scaled is larger than query scaled #1870

Merged
merged 12 commits into the base branch on Mar 7, 2022
23 changes: 17 additions & 6 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,11 +711,8 @@ def gather(args):
# optionally calculate and save prefetch csv
if prefetch_csvout_fp:
assert scaled
# calculate expected threshold
threshold = args.threshold_bp / scaled

# calculate intersection stats and info
prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, threshold)
prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, args.threshold_bp)
# remove match and query signatures; write result to prefetch csv
d = dict(prefetch_result._asdict())
del d['match']
Expand Down Expand Up @@ -1168,7 +1165,9 @@ def prefetch(args):
if args.scaled:
notify(f'downsampling query from scaled={query_mh.scaled} to {int(args.scaled)}')
query_mh = query_mh.downsample(scaled=args.scaled)

notify(f"all sketches will be downsampled to scaled={query_mh.scaled}")
common_scaled = query_mh.scaled

# empty?
if not len(query_mh):
Expand Down Expand Up @@ -1223,9 +1222,20 @@ def prefetch(args):
for result in prefetch_database(query, db, args.threshold_bp):
match = result.match

# ensure we're all on the same page wrt scaled resolution:
common_scaled = max(match.minhash.scaled, query.minhash.scaled,
common_scaled)

query_mh = query.minhash.downsample(scaled=common_scaled)
match_mh = match.minhash.downsample(scaled=common_scaled)

if ident_mh.scaled != common_scaled:
ident_mh = ident_mh.downsample(scaled=common_scaled)
if noident_mh.scaled != common_scaled:
noident_mh = noident_mh.downsample(scaled=common_scaled)

# track found & "untouched" hashes.
match_mh = match.minhash.downsample(scaled=query.minhash.scaled)
ident_mh += query.minhash & match_mh.flatten()
ident_mh += query_mh & match_mh.flatten()
noident_mh.remove_many(match.minhash)
ctb marked this conversation as resolved.
Show resolved Hide resolved

# output match info as we go
Expand Down Expand Up @@ -1265,6 +1275,7 @@ def prefetch(args):
assert len(query_mh) == len(ident_mh) + len(noident_mh)
notify(f"of {len(query_mh)} distinct query hashes, {len(ident_mh)} were found in matches above threshold.")
notify(f"a total of {len(noident_mh)} query hashes remain unmatched.")
notify(f"final scaled value (max across query and all matches) is {common_scaled}")

if args.save_matching_hashes:
filename = args.save_matching_hashes
Expand Down
12 changes: 6 additions & 6 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,16 +469,20 @@ def __next__(self):
'intersect_bp, jaccard, max_containment, f_query_match, f_match_query, match, match_filename, match_name, match_md5, match_bp, query, query_filename, query_name, query_md5, query_bp')


def calculate_prefetch_info(query, match, scaled, threshold):
def calculate_prefetch_info(query, match, scaled, threshold_bp):
"""
For a single query and match, calculate all search info and return a PrefetchResult.
"""
# base intersections on downsampled minhashes
query_mh = query.minhash

scaled = max(scaled, match.minhash.scaled)
query_mh = query_mh.downsample(scaled=scaled)
db_mh = match.minhash.flatten().downsample(scaled=scaled)

# calculate db match intersection with query hashes:
intersect_mh = query_mh & db_mh
threshold = threshold_bp / scaled
assert len(intersect_mh) >= threshold

f_query_match = db_mh.contained_by(query_mh)
Expand Down Expand Up @@ -515,12 +519,8 @@ def prefetch_database(query, database, threshold_bp):
scaled = query_mh.scaled
assert scaled

# for testing/double-checking purposes, calculate expected threshold -
threshold = threshold_bp / scaled

# iterate over all signatures in database, find matches

for result in database.prefetch(query, threshold_bp):
match = result.signature
result = calculate_prefetch_info(query, match, scaled, threshold)
result = calculate_prefetch_info(query, match, scaled, threshold_bp)
yield result
61 changes: 61 additions & 0 deletions tests/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,67 @@ def test_prefetch_select_query_ksize(runtmp, linear_gather):
assert 'of 4476 distinct query hashes, 4476 were found in matches above threshold.' in c.last_result.err


def test_prefetch_subject_scaled_is_larger(runtmp, linear_gather):
    # test prefetch when a subject database has a LARGER scaled value than
    # the query: the query (scaled=1000) is searched against dbs at
    # scaled=10000, and everything must be downsampled to the max scaled.
    c = runtmp

    # make a query sketch with scaled=1000 (sourmash's default for 'sketch dna')
    fa = utils.get_test_data('genome-s10.fa.gz')
    c.run_sourmash('sketch', 'dna', fa, '-o', 'query.sig')
    assert os.path.exists(runtmp.output('query.sig'))

    # this has a scaled of 10000, from same genome:
    against1 = utils.get_test_data('scaled/genome-s10.fa.gz.sig')
    against2 = utils.get_test_data('scaled/all.sbt.zip')
    against3 = utils.get_test_data('scaled/all.lca.json')

    # run against large scaled, then small (self)
    c.run_sourmash('prefetch', 'query.sig', against1, against2, against3,
                   'query.sig', linear_gather)
    print(c.last_result.status)
    print(c.last_result.out)
    print(c.last_result.err)

    # all query hashes should be found, and the reported common scaled
    # must be the max across query + matches (10000).
    assert c.last_result.status == 0
    assert 'total of 8 matching signatures.' in c.last_result.err
    assert 'of 48 distinct query hashes, 48 were found in matches above threshold.' in c.last_result.err
    assert 'final scaled value (max across query and all matches) is 10000' in c.last_result.err


def test_prefetch_subject_scaled_is_larger_outsigs(runtmp, linear_gather):
    # like test_prefetch_subject_scaled_is_larger, but also checks that
    # --save-matches writes the ORIGINAL (non-downsampled) sketches: the
    # saved matches must retain both scaled=1000 and scaled=10000 values.
    c = runtmp

    # make a query sketch with scaled=1000 (sourmash's default for 'sketch dna')
    fa = utils.get_test_data('genome-s10.fa.gz')
    c.run_sourmash('sketch', 'dna', fa, '-o', 'query.sig')
    assert os.path.exists(runtmp.output('query.sig'))

    # this has a scaled of 10000, from same genome:
    against1 = utils.get_test_data('scaled/genome-s10.fa.gz.sig')
    against2 = utils.get_test_data('scaled/all.sbt.zip')
    against3 = utils.get_test_data('scaled/all.lca.json')

    # run against large scaled, then small (self)
    c.run_sourmash('prefetch', 'query.sig', against1, against2, against3,
                   'query.sig', linear_gather, '--save-matches', 'matches.sig')
    print(c.last_result.status)
    print(c.last_result.out)
    print(c.last_result.err)

    assert c.last_result.status == 0
    assert 'total of 8 matching signatures.' in c.last_result.err
    assert 'of 48 distinct query hashes, 48 were found in matches above threshold.' in c.last_result.err
    assert 'final scaled value (max across query and all matches) is 10000' in c.last_result.err

    # make sure non-downsampled sketches were saved.
    matches = sourmash.load_file_as_signatures(runtmp.output('matches.sig'))
    scaled_vals = {match.minhash.scaled for match in matches}
    assert 1000 in scaled_vals
    assert 10000 in scaled_vals
    assert len(scaled_vals) == 2


def test_prefetch_query_abund(runtmp, linear_gather):
c = runtmp

Expand Down