diff --git a/Cargo.toml b/Cargo.toml index 44ce59254..d4d7d3e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ repository = "/~https://github.com/molpopgen/tskit_rust" thiserror = "1.0" libc = "0.2.81" streaming-iterator = "0.1.5" +bitflags = "1.2.1" [dev-dependencies] serde = {version = "1.0.118", features = ["derive"]} diff --git a/src/_macros.rs b/src/_macros.rs index 32c44425f..06a172e08 100644 --- a/src/_macros.rs +++ b/src/_macros.rs @@ -172,6 +172,15 @@ macro_rules! wrapped_tsk_array_traits { }; } +macro_rules! err_if_not_tracking_samples { + ($flags: expr, $rv: expr) => { + match $flags.contains(crate::TreeFlags::SAMPLE_LISTS) { + false => Err(TskitError::NotTrackingSamples), + true => Ok($rv), + } + }; +} + #[cfg(test)] mod test { use crate::error::TskitError; diff --git a/src/error.rs b/src/error.rs index 9725269d1..a9ad1294b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,11 @@ pub enum TskitError { /// arrays allocated on the C side. #[error("Invalid index")] IndexError, + /// Raised when samples are requested from + /// [`Tree`] objects, but sample lists are + /// not being updated. + #[error("Not tracking samples in Trees")] + NotTrackingSamples, /// Wrapper around tskit C API error codes. #[error("{}", get_tskit_error_message(*code))] ErrorCode { code: i32 }, diff --git a/src/lib.rs b/src/lib.rs index 9543abb71..54b2fcfc1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ pub use node_table::NodeTable; pub use population_table::PopulationTable; pub use site_table::SiteTable; pub use table_collection::TableCollection; -pub use trees::{NodeIterator, NodeTraversalOrder, Tree, TreeSequence}; +pub use trees::{NodeIterator, NodeTraversalOrder, Tree, TreeFlags, TreeSequence}; /// Handles return codes from low-level tskit functions. /// diff --git a/src/trees.rs b/src/trees.rs index 5ce8ac2e3..f7501af72 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -2,14 +2,25 @@ use crate::bindings as ll_bindings; use crate::error::TskitError; use crate::ffi::{TskitTypeAccess, WrapTskitConsumingType}; use crate::{tsk_flags_t, tsk_id_t, tsk_size_t, TableCollection, TSK_NULL}; +use bitflags::bitflags; use ll_bindings::{tsk_tree_free, tsk_treeseq_free}; use streaming_iterator::StreamingIterator; +bitflags! { + #[derive(Default)] + pub struct TreeFlags: tsk_flags_t { + const NONE = 0; + const SAMPLE_LISTS = ll_bindings::TSK_SAMPLE_LISTS; + const NO_SAMPLE_COUNTS = ll_bindings::TSK_NO_SAMPLE_COUNTS; + } +} + pub struct Tree { inner: Box, current_tree: i32, advanced: bool, num_nodes: tsk_size_t, + flags: TreeFlags, } pub type BoxedNodeIterator = Box; @@ -18,19 +29,20 @@ drop_for_tskit_type!(Tree, tsk_tree_free); tskit_type_access!(Tree, ll_bindings::tsk_tree_t); impl Tree { - fn wrap(num_nodes: tsk_size_t) -> Self { + fn wrap(num_nodes: tsk_size_t, flags: TreeFlags) -> Self { let temp: std::mem::MaybeUninit = std::mem::MaybeUninit::uninit(); Self { inner: unsafe { Box::::new(temp.assume_init()) }, current_tree: 0, advanced: false, num_nodes, + flags, } } - fn new(ts: &TreeSequence) -> Result { - let mut tree = Self::wrap(ts.consumed.nodes().num_rows()); - let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), 0) }; + fn new(ts: &TreeSequence, flags: TreeFlags) -> Result { + let mut tree = Self::wrap(ts.consumed.nodes().num_rows(), flags); + let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) }; handle_tsk_return_value!(rv, tree) } @@ -55,6 +67,36 @@ impl Tree { crate::ffi::TskIdArray::new(self.inner.parent, self.inner.num_nodes) } + fn samples_array(&self) -> Result { + let num_samples = + unsafe { ll_bindings::tsk_treeseq_get_num_samples((*self.as_ptr()).tree_sequence) }; + err_if_not_tracking_samples!( + self.flags, + crate::ffi::TskIdArray::new(self.inner.samples, num_samples) + ) + } + + fn next_sample_array(&self) -> Result { + err_if_not_tracking_samples!( + self.flags, + crate::ffi::TskIdArray::new(self.inner.next_sample, self.inner.num_nodes) + ) + } + + fn left_sample_array(&self) -> Result { + err_if_not_tracking_samples!( + self.flags, + crate::ffi::TskIdArray::new(self.inner.left_sample, self.inner.num_nodes) + ) + } + + fn right_sample_array(&self) -> Result { + err_if_not_tracking_samples!( + self.flags, + crate::ffi::TskIdArray::new(self.inner.right_sample, self.inner.num_nodes) + ) + } + fn left_sib_array(&self) -> crate::ffi::TskIdArray { crate::ffi::TskIdArray::new(self.inner.left_sib, self.inner.num_nodes) } @@ -71,6 +113,20 @@ impl Tree { crate::ffi::TskIdArray::new(self.inner.right_child, self.inner.num_nodes) } + fn left_sample(&self, u: tsk_id_t) -> Result { + if !self.flags.contains(TreeFlags::SAMPLE_LISTS) { + return Err(TskitError::NotTrackingSamples); + } + unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.left_sample); + } + + fn right_sample(&self, u: tsk_id_t) -> Result { + if !self.flags.contains(TreeFlags::SAMPLE_LISTS) { + return Err(TskitError::NotTrackingSamples); + } + unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.right_sample); + } + pub fn interval(&self) -> (f64, f64) { unsafe { ((*self.as_ptr()).left, (*self.as_ptr()).right) } } @@ -104,7 +160,7 @@ impl Tree { unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.right_sib); } - pub fn sample_list(&self) -> Vec { + pub fn samples_to_vec(&self) -> Vec { let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples((*self.as_ptr()).tree_sequence) }; let mut rv = vec![]; @@ -126,6 +182,11 @@ impl Tree { Ok(Box::new(iter)) } + pub fn samples(&self, u: tsk_id_t) -> Result { + let iter = SamplesIterator::new(self, u)?; + Ok(Box::new(iter)) + } + pub fn roots(&self) -> BoxedNodeIterator { Box::new(RootIterator::new(self)) } @@ -361,6 +422,52 @@ impl NodeIterator for PathToRootIterator { } } +struct SamplesIterator { + current_node: Option, + next_sample_index: tsk_id_t, + last_sample_index: tsk_id_t, + next_sample: crate::ffi::TskIdArray, + samples: crate::ffi::TskIdArray, +} + +impl SamplesIterator { + fn new(tree: &Tree, u: tsk_id_t) -> Result { + let rv = SamplesIterator { + current_node: None, + next_sample_index: tree.left_sample(u)?, + last_sample_index: tree.right_sample(u)?, + next_sample: tree.next_sample_array()?, + samples: tree.samples_array()?, + }; + + Ok(rv) + } +} + +impl NodeIterator for SamplesIterator { + fn next_node(&mut self) { + self.current_node = match self.next_sample_index { + TSK_NULL => None, + r => { + if r == self.last_sample_index { + let cr = Some(self.samples[r]); + self.next_sample_index = TSK_NULL; + cr + } else { + assert!(r >= 0); + let cr = Some(self.samples[r]); + self.next_sample_index = self.next_sample[r]; + cr + } + } + }; + } + + fn current_node(&mut self) -> Option { + self.current_node + } +} + /// A tree sequence. /// /// This is a thin wrapper around the C type `tsk_treeseq_t`. @@ -415,13 +522,13 @@ impl TreeSequence { self.consumed.deepcopy() } - pub fn tree_iterator(&self) -> Result { - let tree = Tree::new(self)?; + pub fn tree_iterator(&self, flags: TreeFlags) -> Result { + let tree = Tree::new(self, flags)?; Ok(tree) } - pub fn sample_list(&self) -> Vec { + pub fn samples_to_vec(&self) -> Vec { let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }; let mut rv = vec![]; @@ -461,7 +568,7 @@ mod test_trees { fn test_create_treeseq_new_from_tables() { let tables = make_small_table_collection(); let treeseq = TreeSequence::new(tables).unwrap(); - let samples = treeseq.sample_list(); + let samples = treeseq.samples_to_vec(); assert_eq!(samples.len(), 2); for i in 1..3 { assert_eq!(samples[i - 1], i as tsk_id_t); @@ -479,11 +586,11 @@ mod test_trees { let tables = make_small_table_collection(); let treeseq = tables.tree_sequence().unwrap(); let mut ntrees = 0; - let mut tree_iter = treeseq.tree_iterator().unwrap(); + let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap(); while let Some(tree) = tree_iter.next() { ntrees += 1; assert_eq!(tree.current_tree, ntrees); - let samples = tree.sample_list(); + let samples = tree.samples_to_vec(); assert_eq!(samples.len(), 2); for i in 1..3 { assert_eq!(samples[i - 1], i as tsk_id_t); @@ -511,7 +618,7 @@ mod test_trees { let mut tables = TableCollection::new(100.).unwrap(); tables.build_index(0).unwrap(); let treeseq = tables.tree_sequence().unwrap(); - let mut tree_iter = treeseq.tree_iterator().unwrap(); + let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap(); while let Some(tree) = tree_iter.next() { let mut num_roots = 0; for _ in tree.roots() { @@ -520,4 +627,46 @@ mod test_trees { assert_eq!(num_roots, 0); } } + + #[should_panic] + #[test] + fn test_samples_iterator_error_when_not_tracking_samples() { + let tables = make_small_table_collection(); + let treeseq = tables.tree_sequence().unwrap(); + + let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap(); + if let Some(tree) = tree_iter.next() { + for n in tree.nodes(NodeTraversalOrder::Preorder) { + for _ in tree.samples(n).unwrap() {} + } + } + } + + #[test] + fn test_iterate_samples() { + let tables = make_small_table_collection(); + let treeseq = tables.tree_sequence().unwrap(); + + let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap(); + if let Some(tree) = tree_iter.next() { + let mut s = vec![]; + for i in tree.samples(0).unwrap() { + s.push(i); + } + assert_eq!(s.len(), 2); + assert_eq!(s[0], 1); + assert_eq!(s[1], 2); + + for u in 1..3 { + let mut s = vec![]; + for i in tree.samples(u).unwrap() { + s.push(i); + } + assert_eq!(s.len(), 1); + assert_eq!(s[0], u); + } + } else { + panic!("Expected a tree"); + } + } } diff --git a/tskit_rust_examples/tree_traversals.rs b/tskit_rust_examples/tree_traversals.rs index e6d7edbb1..3deb731c6 100644 --- a/tskit_rust_examples/tree_traversals.rs +++ b/tskit_rust_examples/tree_traversals.rs @@ -5,7 +5,7 @@ use tskit::NodeIterator; // "Manual" traversal from samples to root fn traverse_upwards(tree: &tskit::Tree) -> () { - let samples = tree.sample_list(); + let samples = tree.samples_to_vec(); for s in samples.iter() { let mut u = *s; @@ -17,7 +17,7 @@ fn traverse_upwards(tree: &tskit::Tree) -> () { // Iterate from each node up to its root. fn traverse_upwards_with_closure(tree: &tskit::Tree) -> () { - let samples = tree.sample_list(); + let samples = tree.samples_to_vec(); for s in samples.iter() { let mut steps_to_root = 0; @@ -48,7 +48,7 @@ fn main() { let treeseq = tskit::TreeSequence::load(&treefile).unwrap(); - let mut tree_iterator = treeseq.tree_iterator().unwrap(); + let mut tree_iterator = treeseq.tree_iterator(tskit::TreeFlags::default()).unwrap(); while let Some(tree) = tree_iterator.next() { traverse_upwards(&tree);