Skip to content

Commit

Permalink
Merge pull request #41 from molpopgen/sample_list_traversal
Browse files Browse the repository at this point in the history
Add support for sample list traversal.
  • Loading branch information
molpopgen authored Apr 8, 2021
2 parents f1acfb0 + cb0e9a4 commit 40e8989
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 16 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ repository = "/~https://github.com/molpopgen/tskit_rust"
thiserror = "1.0"
libc = "0.2.81"
streaming-iterator = "0.1.5"
bitflags = "1.2.1"

[dev-dependencies]
serde = {version = "1.0.118", features = ["derive"]}
Expand Down
9 changes: 9 additions & 0 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ macro_rules! wrapped_tsk_array_traits {
};
}

/// Evaluates to `Ok($rv)` when `$flags` contains
/// `crate::TreeFlags::SAMPLE_LISTS`, and to
/// `Err(TskitError::NotTrackingSamples)` otherwise.
///
/// `$rv` is only evaluated on the success path.
macro_rules! err_if_not_tracking_samples {
    ($flags: expr, $rv: expr) => {
        if $flags.contains(crate::TreeFlags::SAMPLE_LISTS) {
            Ok($rv)
        } else {
            // Fully qualified so the expansion does not depend on the
            // call site having `TskitError` imported.
            Err(crate::error::TskitError::NotTrackingSamples)
        }
    };
}

#[cfg(test)]
mod test {
use crate::error::TskitError;
Expand Down
5 changes: 5 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ pub enum TskitError {
/// arrays allocated on the C side.
#[error("Invalid index")]
IndexError,
/// Raised when samples are requested from
/// [`Tree`] objects, but sample lists are
/// not being updated.
#[error("Not tracking samples in Trees")]
NotTrackingSamples,
/// Wrapper around tskit C API error codes.
#[error("{}", get_tskit_error_message(*code))]
ErrorCode { code: i32 },
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use node_table::NodeTable;
pub use population_table::PopulationTable;
pub use site_table::SiteTable;
pub use table_collection::TableCollection;
pub use trees::{NodeIterator, NodeTraversalOrder, Tree, TreeSequence};
pub use trees::{NodeIterator, NodeTraversalOrder, Tree, TreeFlags, TreeSequence};

/// Handles return codes from low-level tskit functions.
///
Expand Down
173 changes: 161 additions & 12 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@ use crate::bindings as ll_bindings;
use crate::error::TskitError;
use crate::ffi::{TskitTypeAccess, WrapTskitConsumingType};
use crate::{tsk_flags_t, tsk_id_t, tsk_size_t, TableCollection, TSK_NULL};
use bitflags::bitflags;
use ll_bindings::{tsk_tree_free, tsk_treeseq_free};
use streaming_iterator::StreamingIterator;

bitflags! {
/// Options controlling how a [`Tree`] is initialized.
///
/// The raw bits are handed to the C function `tsk_tree_init`.
#[derive(Default)]
pub struct TreeFlags: tsk_flags_t {
/// No options set (all bits zero).
const NONE = 0;
/// Mirrors the C constant `TSK_SAMPLE_LISTS`.
///
/// Must be set for the sample-list accessors on [`Tree`] to succeed;
/// without it they return [`TskitError::NotTrackingSamples`].
const SAMPLE_LISTS = ll_bindings::TSK_SAMPLE_LISTS;
/// Mirrors the C constant `TSK_NO_SAMPLE_COUNTS`.
const NO_SAMPLE_COUNTS = ll_bindings::TSK_NO_SAMPLE_COUNTS;
}
}

/// A single tree obtained from a [`TreeSequence`].
///
/// Wraps the C type `tsk_tree_t`.
pub struct Tree {
// The underlying C tree, heap-allocated.
inner: Box<ll_bindings::tsk_tree_t>,
// Starts at 0 when the wrapper is created.
// NOTE(review): the tests expect this to equal the number of trees
// visited so far, so it is presumably bumped on each advance — confirm.
current_tree: i32,
// Starts as `false`.
// NOTE(review): updated outside this view; presumably records whether
// the most recent advance succeeded — confirm.
advanced: bool,
// Number of rows in the node table of the owning tree sequence.
num_nodes: tsk_size_t,
// Flags this tree was initialized with.
flags: TreeFlags,
}

/// A boxed trait object implementing [`NodeIterator`].
pub type BoxedNodeIterator = Box<dyn NodeIterator>;
Expand All @@ -18,19 +29,20 @@ drop_for_tskit_type!(Tree, tsk_tree_free);
tskit_type_access!(Tree, ll_bindings::tsk_tree_t);

impl Tree {
fn wrap(num_nodes: tsk_size_t) -> Self {
/// Allocates the Rust-side wrapper around a `tsk_tree_t`.
///
/// The C struct is zero-filled here; it is not a usable tree until
/// `tsk_tree_init` has been called on it (see `Tree::new`).
fn wrap(num_nodes: tsk_size_t, flags: TreeFlags) -> Self {
    let temp: std::mem::MaybeUninit<ll_bindings::tsk_tree_t> =
        std::mem::MaybeUninit::zeroed();
    Self {
        // SAFETY: assumes `tsk_tree_t` consists only of integers, floats and
        // raw pointers (as bindgen emits for the C struct), for which the
        // all-zero bit pattern is a valid value. Calling `assume_init` on
        // `MaybeUninit::uninit()` instead would be undefined behavior.
        inner: unsafe { Box::<ll_bindings::tsk_tree_t>::new(temp.assume_init()) },
        current_tree: 0,
        advanced: false,
        num_nodes,
        flags,
    }
}

fn new(ts: &TreeSequence) -> Result<Self, TskitError> {
let mut tree = Self::wrap(ts.consumed.nodes().num_rows());
let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), 0) };
/// Creates a tree for `ts` and initializes it with `tsk_tree_init`,
/// forwarding the raw bits of `flags`.
///
/// The C return code is turned into the `Result` by
/// `handle_tsk_return_value!`.
fn new(ts: &TreeSequence, flags: TreeFlags) -> Result<Self, TskitError> {
    let num_nodes = ts.consumed.nodes().num_rows();
    let mut tree = Self::wrap(num_nodes, flags);
    let rv = unsafe {
        ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits())
    };
    handle_tsk_return_value!(rv, tree)
}

Expand All @@ -55,6 +67,36 @@ impl Tree {
crate::ffi::TskIdArray::new(self.inner.parent, self.inner.num_nodes)
}

/// Wraps the C-side `samples` array; its length is the sample count
/// reported by `tsk_treeseq_get_num_samples` for the owning tree sequence.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn samples_array(&self) -> Result<crate::ffi::TskIdArray, TskitError> {
    let treeseq = unsafe { (*self.as_ptr()).tree_sequence };
    let len = unsafe { ll_bindings::tsk_treeseq_get_num_samples(treeseq) };
    err_if_not_tracking_samples!(
        self.flags,
        crate::ffi::TskIdArray::new(self.inner.samples, len)
    )
}

/// Wraps the C-side `next_sample` array, using the tree's `num_nodes`
/// as its length.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn next_sample_array(&self) -> Result<crate::ffi::TskIdArray, TskitError> {
    let (data, len) = (self.inner.next_sample, self.inner.num_nodes);
    err_if_not_tracking_samples!(self.flags, crate::ffi::TskIdArray::new(data, len))
}

/// Wraps the C-side `left_sample` array, using the tree's `num_nodes`
/// as its length.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn left_sample_array(&self) -> Result<crate::ffi::TskIdArray, TskitError> {
    let (data, len) = (self.inner.left_sample, self.inner.num_nodes);
    err_if_not_tracking_samples!(self.flags, crate::ffi::TskIdArray::new(data, len))
}

/// Wraps the C-side `right_sample` array, using the tree's `num_nodes`
/// as its length.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn right_sample_array(&self) -> Result<crate::ffi::TskIdArray, TskitError> {
    let (data, len) = (self.inner.right_sample, self.inner.num_nodes);
    err_if_not_tracking_samples!(self.flags, crate::ffi::TskIdArray::new(data, len))
}

/// Wraps the C-side `left_sib` array, using the tree's `num_nodes`
/// as its length.
fn left_sib_array(&self) -> crate::ffi::TskIdArray {
crate::ffi::TskIdArray::new(self.inner.left_sib, self.inner.num_nodes)
}
Expand All @@ -71,6 +113,20 @@ impl Tree {
crate::ffi::TskIdArray::new(self.inner.right_child, self.inner.num_nodes)
}

/// Looks up the `left_sample` entry for node `u`.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn left_sample(&self, u: tsk_id_t) -> Result<tsk_id_t, TskitError> {
if !self.flags.contains(TreeFlags::SAMPLE_LISTS) {
return Err(TskitError::NotTrackingSamples);
}
// NOTE(review): the macro is defined elsewhere; presumably it range-checks
// `u` against `[0, num_nodes)` and returns from this function — confirm.
unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.left_sample);
}

/// Looks up the `right_sample` entry for node `u`.
///
/// Returns `TskitError::NotTrackingSamples` unless the tree's flags
/// contain `TreeFlags::SAMPLE_LISTS`.
fn right_sample(&self, u: tsk_id_t) -> Result<tsk_id_t, TskitError> {
if !self.flags.contains(TreeFlags::SAMPLE_LISTS) {
return Err(TskitError::NotTrackingSamples);
}
// NOTE(review): the macro is defined elsewhere; presumably it range-checks
// `u` against `[0, num_nodes)` and returns from this function — confirm.
unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.right_sample);
}

/// Returns the `(left, right)` pair stored on the underlying C tree.
pub fn interval(&self) -> (f64, f64) {
    let tree = self.as_ptr();
    let left = unsafe { (*tree).left };
    let right = unsafe { (*tree).right };
    (left, right)
}
Expand Down Expand Up @@ -104,7 +160,7 @@ impl Tree {
unsafe_tsk_column_access!(u, 0, self.num_nodes, self.inner.right_sib);
}

pub fn sample_list(&self) -> Vec<tsk_id_t> {
pub fn samples_to_vec(&self) -> Vec<tsk_id_t> {
let num_samples =
unsafe { ll_bindings::tsk_treeseq_get_num_samples((*self.as_ptr()).tree_sequence) };
let mut rv = vec![];
Expand All @@ -126,6 +182,11 @@ impl Tree {
Ok(Box::new(iter))
}

/// Returns a boxed iterator over the sample nodes recorded in the tree's
/// sample list for node `u`.
///
/// # Errors
///
/// Propagates any error from `SamplesIterator::new`, which includes
/// `TskitError::NotTrackingSamples` when the tree was not created with
/// `TreeFlags::SAMPLE_LISTS`.
pub fn samples(&self, u: tsk_id_t) -> Result<BoxedNodeIterator, TskitError> {
    SamplesIterator::new(self, u).map(|iter| Box::new(iter) as BoxedNodeIterator)
}

/// Returns a boxed `RootIterator` built from this tree.
pub fn roots(&self) -> BoxedNodeIterator {
    let iter = RootIterator::new(self);
    Box::new(iter)
}
Expand Down Expand Up @@ -361,6 +422,52 @@ impl NodeIterator for PathToRootIterator {
}
}

// Iterator state for walking the sample list of a single node.
struct SamplesIterator {
// Sample produced by the latest step; `None` before the first step
// and once the list is exhausted.
current_node: Option<tsk_id_t>,
// Index of the next sample-list entry to visit; `TSK_NULL` when done.
next_sample_index: tsk_id_t,
// Index of the final sample-list entry for the node.
last_sample_index: tsk_id_t,
// Maps a sample-list index to the index that follows it.
next_sample: crate::ffi::TskIdArray,
// Maps a sample-list index to the value handed out by the iterator.
samples: crate::ffi::TskIdArray,
}

impl SamplesIterator {
fn new(tree: &Tree, u: tsk_id_t) -> Result<Self, TskitError> {
let rv = SamplesIterator {
current_node: None,
next_sample_index: tree.left_sample(u)?,
last_sample_index: tree.right_sample(u)?,
next_sample: tree.next_sample_array()?,
samples: tree.samples_array()?,
};

Ok(rv)
}
}

impl NodeIterator for SamplesIterator {
    /// Advances one step along the sample list.
    ///
    /// Stores the sample at the current index in `current_node`, then
    /// either follows `next_sample` or, when the current index is the last
    /// one for the node, marks the iterator as finished with `TSK_NULL`.
    fn next_node(&mut self) {
        let index = self.next_sample_index;
        self.current_node = if index == TSK_NULL {
            None
        } else {
            let at_end = index == self.last_sample_index;
            if !at_end {
                assert!(index >= 0);
            }
            let sample = self.samples[index];
            self.next_sample_index = if at_end {
                TSK_NULL
            } else {
                self.next_sample[index]
            };
            Some(sample)
        };
    }

    /// Returns the value stored by the latest call to `next_node`.
    fn current_node(&mut self) -> Option<tsk_id_t> {
        self.current_node
    }
}

/// A tree sequence.
///
/// This is a thin wrapper around the C type `tsk_treeseq_t`.
Expand Down Expand Up @@ -415,13 +522,13 @@ impl TreeSequence {
self.consumed.deepcopy()
}

pub fn tree_iterator(&self) -> Result<Tree, TskitError> {
let tree = Tree::new(self)?;
/// Creates a [`Tree`] for this tree sequence, initialized with `flags`.
///
/// # Errors
///
/// Returns whatever error `Tree::new` reports.
pub fn tree_iterator(&self, flags: TreeFlags) -> Result<Tree, TskitError> {
    Tree::new(self, flags)
}

pub fn sample_list(&self) -> Vec<tsk_id_t> {
pub fn samples_to_vec(&self) -> Vec<tsk_id_t> {
let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
let mut rv = vec![];

Expand Down Expand Up @@ -461,7 +568,7 @@ mod test_trees {
fn test_create_treeseq_new_from_tables() {
let tables = make_small_table_collection();
let treeseq = TreeSequence::new(tables).unwrap();
let samples = treeseq.sample_list();
let samples = treeseq.samples_to_vec();
assert_eq!(samples.len(), 2);
for i in 1..3 {
assert_eq!(samples[i - 1], i as tsk_id_t);
Expand All @@ -479,11 +586,11 @@ mod test_trees {
let tables = make_small_table_collection();
let treeseq = tables.tree_sequence().unwrap();
let mut ntrees = 0;
let mut tree_iter = treeseq.tree_iterator().unwrap();
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
while let Some(tree) = tree_iter.next() {
ntrees += 1;
assert_eq!(tree.current_tree, ntrees);
let samples = tree.sample_list();
let samples = tree.samples_to_vec();
assert_eq!(samples.len(), 2);
for i in 1..3 {
assert_eq!(samples[i - 1], i as tsk_id_t);
Expand Down Expand Up @@ -511,7 +618,7 @@ mod test_trees {
let mut tables = TableCollection::new(100.).unwrap();
tables.build_index(0).unwrap();
let treeseq = tables.tree_sequence().unwrap();
let mut tree_iter = treeseq.tree_iterator().unwrap();
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
while let Some(tree) = tree_iter.next() {
let mut num_roots = 0;
for _ in tree.roots() {
Expand All @@ -520,4 +627,46 @@ mod test_trees {
assert_eq!(num_roots, 0);
}
}

#[test]
fn test_samples_iterator_error_when_not_tracking_samples() {
    let tables = make_small_table_collection();
    let treeseq = tables.tree_sequence().unwrap();

    // Default flags do not include `SAMPLE_LISTS`, so every request for a
    // samples iterator must fail with the dedicated error variant. Checking
    // the variant directly (instead of `#[should_panic]`) keeps an unrelated
    // panic from making this test pass.
    let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
    if let Some(tree) = tree_iter.next() {
        let mut num_checked = 0;
        for n in tree.nodes(NodeTraversalOrder::Preorder) {
            assert!(matches!(
                tree.samples(n),
                Err(crate::error::TskitError::NotTrackingSamples)
            ));
            num_checked += 1;
        }
        // Guard against the loop body never running.
        assert!(num_checked > 0);
    } else {
        panic!("Expected a tree");
    }
}

#[test]
fn test_iterate_samples() {
    let tables = make_small_table_collection();
    let treeseq = tables.tree_sequence().unwrap();
    let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();

    let tree = match tree_iter.next() {
        Some(tree) => tree,
        None => panic!("Expected a tree"),
    };

    // Node 0 is expected to yield samples 1 and 2, in that order.
    let mut root_samples = Vec::new();
    root_samples.extend(tree.samples(0).unwrap());
    assert_eq!(root_samples.len(), 2);
    assert_eq!(root_samples[0], 1);
    assert_eq!(root_samples[1], 2);

    // Each of nodes 1 and 2 is expected to yield only itself.
    for u in 1..3 {
        let mut own_samples = Vec::new();
        own_samples.extend(tree.samples(u).unwrap());
        assert_eq!(own_samples.len(), 1);
        assert_eq!(own_samples[0], u);
    }
}
}
6 changes: 3 additions & 3 deletions tskit_rust_examples/tree_traversals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tskit::NodeIterator;

// "Manual" traversal from samples to root
fn traverse_upwards(tree: &tskit::Tree) -> () {
let samples = tree.sample_list();
let samples = tree.samples_to_vec();

for s in samples.iter() {
let mut u = *s;
Expand All @@ -17,7 +17,7 @@ fn traverse_upwards(tree: &tskit::Tree) -> () {

// Iterate from each node up to its root.
fn traverse_upwards_with_closure(tree: &tskit::Tree) -> () {
let samples = tree.sample_list();
let samples = tree.samples_to_vec();

for s in samples.iter() {
let mut steps_to_root = 0;
Expand Down Expand Up @@ -48,7 +48,7 @@ fn main() {

let treeseq = tskit::TreeSequence::load(&treefile).unwrap();

let mut tree_iterator = treeseq.tree_iterator().unwrap();
let mut tree_iterator = treeseq.tree_iterator(tskit::TreeFlags::default()).unwrap();

while let Some(tree) = tree_iterator.next() {
traverse_upwards(&tree);
Expand Down

0 comments on commit 40e8989

Please sign in to comment.