containment check perf boost #320

Merged
merged 1 commit on Feb 5, 2024
151 changes: 100 additions & 51 deletions workflows/index-generation/ncbi-compress/src/ncbi_compress.rs
@@ -1,4 +1,5 @@
pub mod ncbi_compress {
use std::cmp::Ordering;
use std::ops::AddAssign;
use std::path::Path;
use std::{borrow::BorrowMut, fs};
@@ -22,6 +23,51 @@ pub mod ncbi_compress {
Ok(intersect_size as f64 / needle.mins().len() as f64)
}

// check whether the needle KmerMinHash is contained in the haystack KmerMinHash at a given
// threshold. This is slightly more efficient than computing the sourmash containment and
// comparing it to the threshold because it stops early once containment at the threshold is
// no longer possible. For high thresholds this should reduce the number of iterations.
pub fn contains(
needle: &KmerMinHash,
haystack: &KmerMinHash,
similarity_threshold: f64,
) -> Result<bool, SourmashError> {
needle.check_compatible(haystack)?;
let mut needle_n = 0;
let mut haystack_n = 0;
let mut intersect = 0;
let target_size = (needle.mins().len() as f64 * similarity_threshold).ceil() as usize;

// Step through the mins of needle and haystack in order, weaving together
while needle_n < needle.mins().len()
&& haystack_n < haystack.mins().len()
// If there aren't enough mins left in needle to bring the intersection up to the target
// size then we can stop early. In theory haystack could be smaller than needle, so you
// could replace needle.mins().len() - needle_n with the minimum of that and the haystack
// equivalent, but because we pre-sort sequences by length that will never happen. Even
// then it would not be wrong; it would just iterate slightly more.
&& intersect + needle.mins().len() - needle_n >= target_size
{
let needle = needle.mins()[needle_n];
let haystack = haystack.mins()[haystack_n];

match needle.cmp(&haystack) {
Ordering::Less => {
needle_n += 1;
}
Ordering::Greater => {
haystack_n += 1;
}
Ordering::Equal => {
intersect += 1;
needle_n += 1;
haystack_n += 1;
}
};
}
Ok(intersect >= target_size)
}
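
// Illustrative sketch (not part of this diff): a tiny, made-up exercise of `contains` above.
// The hash values, ksize/scaled/seed, and module name are hypothetical; it only assumes the
// sourmash KmerMinHash API already used in this file plus its add_hash method. With 4 mins in
// the needle and a threshold of 0.75, target_size = ceil(4 * 0.75) = 3, so at least 3 shared
// mins are required. At a threshold of 1.0 the walk stops after the first mismatch, because
// the remaining needle mins can no longer reach a target of 4.
#[cfg(test)]
mod contains_sketch {
    use super::contains;
    use sourmash::encodings::HashFunctions;
    use sourmash::sketch::minhash::KmerMinHash;

    #[test]
    fn stops_early_when_threshold_is_unreachable() {
        let mut needle = KmerMinHash::new(1, 31, HashFunctions::murmur64_DNA, 42, false, 0);
        let mut haystack = KmerMinHash::new(1, 31, HashFunctions::murmur64_DNA, 42, false, 0);
        // Insert raw hashes directly so the example stays small and deterministic.
        for h in [1u64, 2, 3, 4] {
            needle.add_hash(h);
        }
        for h in [2u64, 3, 4, 5, 6] {
            haystack.add_hash(h);
        }
        // 3 of the needle's 4 mins are shared: 3 >= ceil(4 * 0.75), so this holds.
        assert!(contains(&needle, &haystack, 0.75).unwrap());
        // Full containment needs all 4 mins; only 3 are shared, so this fails early.
        assert!(!contains(&needle, &haystack, 1.0).unwrap());
    }
}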

fn remove_accession_version(accession: &str) -> &str {
accession.splitn(2, |c| c == '.').next().unwrap()
}
@@ -54,17 +100,21 @@ pub mod ncbi_compress {
log::info!("Creating taxid dir {:?}", taxid_path);
let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
// Build a trie of the accessions in the input fasta
reader.records().enumerate().par_bridge().for_each(|(i, result)| {
// records.par_iter().for_each(|(i, result)| {
let record = result.as_ref().unwrap();
let accession_id = record.id().split_whitespace().next().unwrap();
let accession_no_version = remove_accession_version(accession_id);
// RocksDB supports concurrent reads and writes so this is safe
accession_to_taxid.put(accession_no_version, b"").unwrap();
if i % 1_000_000 == 0 {
log::info!(" Processed {} accessions", i);
}
});
reader
.records()
.enumerate()
.par_bridge()
.for_each(|(i, result)| {
// records.par_iter().for_each(|(i, result)| {
let record = result.as_ref().unwrap();
let accession_id = record.id().split_whitespace().next().unwrap();
let accession_no_version = remove_accession_version(accession_id);
// RocksDB supports concurrent reads and writes so this is safe
accession_to_taxid.put(accession_no_version, b"").unwrap();
if i % 1_000_000 == 0 {
log::info!(" Processed {} accessions", i);
}
});
log::info!(" Finished loading accessions");

mapping_file_path.par_iter().for_each(|mapping_file_path| {
@@ -331,7 +381,8 @@ pub mod ncbi_compress {

// Initialize a temporary vector to store the unique items from each chunk
let mut unique_in_chunk: Vec<(KmerMinHash, fasta::Record)> = Vec::with_capacity(chunk_size);
let mut unique_in_tree_and_chunk: Vec<(KmerMinHash, &fasta::Record)> = Vec::with_capacity(chunk_size);
let mut unique_in_tree_and_chunk: Vec<(KmerMinHash, &fasta::Record)> =
Vec::with_capacity(chunk_size);

loop {
let chunk = records_iter
@@ -347,41 +398,42 @@

// create signatures for each record in the chunk
let chunk_signatures = chunk
.par_iter()
.map(|r| {
let record = r.as_ref().unwrap();
let mut hash;
if is_protein_fasta {
hash = KmerMinHash::new(
scaled,
k,
HashFunctions::murmur64_protein,
seed,
false,
0,
);
hash.add_protein(record.seq()).unwrap();
(hash, record.clone())
} else {
hash = KmerMinHash::new(
scaled,
k,
HashFunctions::murmur64_DNA,
seed,
false,
0,
);
hash.add_sequence(record.seq(), true).unwrap();
(hash, record.clone())
}
}).collect::<Vec<_>>();
.par_iter()
.map(|r| {
let record = r.as_ref().unwrap();
let mut hash;
if is_protein_fasta {
hash = KmerMinHash::new(
scaled,
k,
HashFunctions::murmur64_protein,
seed,
false,
0,
);
hash.add_protein(record.seq()).unwrap();
(hash, record.clone())
} else {
hash = KmerMinHash::new(
scaled,
k,
HashFunctions::murmur64_DNA,
seed,
false,
0,
);
hash.add_sequence(record.seq(), true).unwrap();
(hash, record.clone())
}
})
.collect::<Vec<_>>();

// we need to make sure records within the chunk aren't similar to each other before
// we check them against the larger tree
for (hash, record) in chunk_signatures {
let similar = unique_in_chunk
.par_iter()
.any(|(other, _record)| containment(&hash, &other).unwrap() >= similarity_threshold);
.any(|(other, _record)| contains(&hash, &other, similarity_threshold).unwrap());

if !similar {
unique_in_chunk.push((hash, record));
@@ -401,9 +453,7 @@
// will return them all.
if sketches
.par_iter()
.find_any(|other| {
containment(&hash, other).unwrap() >= similarity_threshold
})
.find_any(|other| contains(&hash, other, similarity_threshold).unwrap())
.is_some()
{
None
@@ -434,7 +484,6 @@ pub mod ncbi_compress {
// for hash in tmp {
// tree.insert(hash);
// }

}
}
}
@@ -472,12 +521,9 @@ mod tests {
test_dir_path_str,
);

let entries = fs::read_dir(test_dir_path_str)
.expect("Failed to read directory");
let entries = fs::read_dir(test_dir_path_str).expect("Failed to read directory");

let entries: Vec<_> = entries
.filter_map(Result::ok)
.collect();
let entries: Vec<_> = entries.filter_map(Result::ok).collect();

// Assert that the directory is not empty
assert!(!entries.is_empty(), "Directory is empty");
@@ -491,7 +537,10 @@

// get the test file path
let test_file_path = format!("{}/{}", test_dir_path_str, truth_file_name);
util::compare_fasta_records_from_files(&test_file_path, &truth_file_path.to_str().unwrap());
util::compare_fasta_records_from_files(
&test_file_path,
&truth_file_path.to_str().unwrap(),
);
}
}
