Skip to content

Commit

Permalink
refactor: Add sys::LLTree
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen committed May 22, 2023
1 parent 6698012 commit 85e048e
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 63 deletions.
2 changes: 2 additions & 0 deletions src/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod bindings;

pub mod flags;
mod tables;
mod tree;
mod treeseq;

// tskit defines this via a type cast
Expand All @@ -17,6 +18,7 @@ pub(crate) const TSK_NULL: bindings::tsk_id_t = -1;

pub use tables::*;
pub use treeseq::LLTreeSeq;
pub use tree::LLTree;

#[non_exhaustive]
#[derive(Error, Debug)]
Expand Down
68 changes: 68 additions & 0 deletions src/sys/tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::ptr::NonNull;

use mbox::MBox;

use super::bindings::tsk_tree_t;
use super::flags::TreeFlags;
use super::Error;
use super::LLTreeSeq;

pub struct LLTree<'treeseq> {
inner: MBox<tsk_tree_t>,
// NOTE: this reference exists becaust tsk_tree_t
// contains a NON-OWNING pointer to tsk_treeseq_t.
// Thus, we could theoretically cause UB without
// tying the rust-side object liftimes together.
#[allow(dead_code)]
treeseq: &'treeseq LLTreeSeq,
}

impl<'treeseq> LLTree<'treeseq> {
pub fn new(treeseq: &'treeseq LLTreeSeq, flags: TreeFlags) -> Result<Self, Error> {
// SAFETY: this is the type we want :)
let temp = unsafe {
libc::malloc(std::mem::size_of::<super::bindings::tsk_tree_t>())
as *mut super::bindings::tsk_tree_t
};

// Get our pointer into MBox ASAP
let nonnull = NonNull::<super::bindings::tsk_tree_t>::new(temp)
.ok_or_else(|| Error::Message("failed to malloc tsk_tree_t".to_string()))?;

// SAFETY: if temp is NULL, we have returned Err already.
let mut inner = unsafe { mbox::MBox::from_non_null_raw(nonnull) };
let mut rv = unsafe {
super::bindings::tsk_tree_init(inner.as_mut(), treeseq.as_ptr(), flags.bits())
};
if rv < 0 {
return Err(Error::Code(rv));
}
// Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
// SAFETY: nobody is null here.
rv = unsafe {
super::bindings::tsk_tree_set_tracked_samples(
inner.as_mut(),
treeseq.num_samples(),
(inner.as_mut()).samples,
)
};
if rv < 0 {
return Err(Error::Code(rv));
}
}
Ok(Self { inner, treeseq })
}

pub fn as_mut_ptr(&mut self) -> *mut tsk_tree_t {
MBox::<tsk_tree_t>::as_mut_ptr(&mut self.inner)
}
}

impl<'treeseq> Drop for LLTree<'treeseq> {
fn drop(&mut self) {
// SAFETY: Mbox<_> cannot hold a NULL ptr
let rv = unsafe { super::bindings::tsk_tree_free(self.inner.as_mut()) };
assert_eq!(rv, 0);
}
}
7 changes: 7 additions & 0 deletions src/sys/treeseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ impl LLTreeSeq {
unsafe { bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }
}

pub fn num_nodes_raw(&self) -> bindings::tsk_size_t {
assert!(!self.as_ptr().is_null());
assert!(!unsafe { *self.as_ptr() }.tables.is_null());
// SAFETY: none of the pointers are null
unsafe { (*(*self.as_ptr()).tables).nodes.num_rows }
}

pub fn kc_distance(&self, other: &Self, lambda: f64) -> Result<f64, Error> {
let mut kc: f64 = f64::NAN;
let kcp: *mut f64 = &mut kc;
Expand Down
76 changes: 14 additions & 62 deletions src/trees/tree.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,22 @@
use std::ops::Deref;
use std::ops::DerefMut;

use super::TreeSequence;
use crate::sys::bindings as ll_bindings;
use crate::sys::{LLTree, LLTreeSeq};
use crate::TreeFlags;
use crate::TreeInterface;
use crate::TskitError;
use ll_bindings::tsk_tree_free;
use std::ptr::NonNull;

/// A Tree.
///
/// Wrapper around `tsk_tree_t`.
pub struct Tree<'treeseq> {
pub(crate) inner: mbox::MBox<ll_bindings::tsk_tree_t>,
// NOTE: this reference exists becaust tsk_tree_t
// contains a NON-OWNING pointer to tsk_treeseq_t.
// Thus, we could theoretically cause UB without
// tying the rust-side object liftimes together.
#[allow(dead_code)]
treeseq: &'treeseq TreeSequence,
pub(crate) inner: LLTree<'treeseq>,
api: TreeInterface,
current_tree: i32,
advanced: bool,
}

impl<'treeseq> Drop for Tree<'treeseq> {
fn drop(&mut self) {
// SAFETY: Mbox<_> cannot hold a NULL ptr
let rv = unsafe { tsk_tree_free(self.inner.as_mut()) };
assert_eq!(rv, 0);
}
}

impl<'treeseq> Deref for Tree<'treeseq> {
type Target = TreeInterface;
fn deref(&self) -> &Self::Target {
Expand All @@ -48,62 +32,30 @@ impl<'treeseq> DerefMut for Tree<'treeseq> {

impl<'treeseq> Tree<'treeseq> {
pub(crate) fn new<F: Into<TreeFlags>>(
ts: &'treeseq TreeSequence,
ts: &'treeseq LLTreeSeq,
flags: F,
) -> Result<Self, TskitError> {
let flags = flags.into();

// SAFETY: this is the type we want :)
let temp = unsafe {
libc::malloc(std::mem::size_of::<ll_bindings::tsk_tree_t>())
as *mut ll_bindings::tsk_tree_t
};

// Get our pointer into MBox ASAP
let nonnull = NonNull::<ll_bindings::tsk_tree_t>::new(temp)
.ok_or_else(|| TskitError::LibraryError("failed to malloc tsk_tree_t".to_string()))?;

// SAFETY: if temp is NULL, we have returned Err already.
let mut tree = unsafe { mbox::MBox::from_non_null_raw(nonnull) };
let mut rv =
unsafe { ll_bindings::tsk_tree_init(tree.as_mut(), ts.as_ptr(), flags.bits()) };
if rv < 0 {
return Err(TskitError::ErrorCode { code: rv });
}
// Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
// SAFETY: nobody is null here.
rv = unsafe {
ll_bindings::tsk_tree_set_tracked_samples(
tree.as_mut(),
ts.num_samples().into(),
(tree.as_mut()).samples,
)
};
}

let num_nodes = unsafe { (*(*ts.as_ptr()).tables).nodes.num_rows };
let mut inner = LLTree::new(ts, flags)?;
let nonnull = std::ptr::NonNull::new(inner.as_mut_ptr()).unwrap();
let num_nodes = ts.num_nodes_raw();
let api = TreeInterface::new(nonnull, num_nodes, num_nodes + 1, flags);
handle_tsk_return_value!(
rv,
Tree {
inner: tree,
treeseq: ts,
current_tree: 0,
advanced: false,
api
}
)
Ok(Self {
inner,
current_tree: 0,
advanced: false,
api,
})
}
}

impl<'ts> streaming_iterator::StreamingIterator for Tree<'ts> {
type Item = Tree<'ts>;
fn advance(&mut self) {
let rv = if self.current_tree == 0 {
unsafe { ll_bindings::tsk_tree_first(self.as_mut_ptr()) }
unsafe { ll_bindings::tsk_tree_first(self.inner.as_mut_ptr()) }
} else {
unsafe { ll_bindings::tsk_tree_next(self.as_mut_ptr()) }
unsafe { ll_bindings::tsk_tree_next(self.inner.as_mut_ptr()) }
};
if rv == 0 {
self.advanced = false;
Expand Down
2 changes: 1 addition & 1 deletion src/trees/treeseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl TreeSequence {
/// }
/// ```
pub fn tree_iterator<F: Into<TreeFlags>>(&self, flags: F) -> Result<Tree, TskitError> {
let tree = Tree::new(self, flags)?;
let tree = Tree::new(&self.inner, flags)?;

Ok(tree)
}
Expand Down

0 comments on commit 85e048e

Please sign in to comment.