diff --git a/src/edge_table.rs b/src/edge_table.rs index b1e27f7d5..1cbd26f84 100644 --- a/src/edge_table.rs +++ b/src/edge_table.rs @@ -1,14 +1,15 @@ use crate::bindings as ll_bindings; use crate::metadata; use crate::{tsk_id_t, tsk_size_t, TskitError}; +use crate::{EdgeId, NodeId}; /// Row of an [`EdgeTable`] pub struct EdgeTableRow { - pub id: tsk_id_t, + pub id: EdgeId, pub left: f64, pub right: f64, - pub parent: tsk_id_t, - pub child: tsk_id_t, + pub parent: NodeId, + pub child: NodeId, pub metadata: Option>, } @@ -26,7 +27,7 @@ impl PartialEq for EdgeTableRow { fn make_edge_table_row(table: &EdgeTable, pos: tsk_id_t) -> Option { if pos < table.num_rows() as tsk_id_t { let rv = EdgeTableRow { - id: pos, + id: pos.into(), left: table.left(pos).unwrap(), right: table.right(pos).unwrap(), parent: table.parent(pos).unwrap(), @@ -87,8 +88,8 @@ impl<'a> EdgeTable<'a> { /// /// Will return [``IndexError``](crate::TskitError::IndexError) /// if ``row`` is out of range. - pub fn parent(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.parent) + pub fn parent + Copy>(&'a self, row: E) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.parent, NodeId) } /// Return the ``child`` value from row ``row`` of the table. @@ -97,8 +98,8 @@ impl<'a> EdgeTable<'a> { /// /// Will return [``IndexError``](crate::TskitError::IndexError) /// if ``row`` is out of range. - pub fn child(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.child) + pub fn child + Copy>(&'a self, row: E) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.child, NodeId) } /// Return the ``left`` value from row ``row`` of the table. @@ -107,8 +108,8 @@ impl<'a> EdgeTable<'a> { /// /// Will return [``IndexError``](crate::TskitError::IndexError) /// if ``row`` is out of range. - pub fn left(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.left) + pub fn left + Copy>(&'a self, row: E) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left) } /// Return the ``right`` value from row ``row`` of the table. @@ -117,15 +118,15 @@ impl<'a> EdgeTable<'a> { /// /// Will return [``IndexError``](crate::TskitError::IndexError) /// if ``row`` is out of range. - pub fn right(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.right) + pub fn right + Copy>(&'a self, row: E) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.right) } pub fn metadata( &'a self, - row: tsk_id_t, + row: EdgeId, ) -> Result, TskitError> { - let buffer = metadata_to_vector!(self, row)?; + let buffer = metadata_to_vector!(self, row.0)?; decode_metadata_row!(T, buffer) } @@ -145,7 +146,7 @@ impl<'a> EdgeTable<'a> { /// # Errors /// /// [`TskitError::IndexError`] if `r` is out of range. - pub fn row(&self, r: tsk_id_t) -> Result { - table_row_access!(r, self, make_edge_table_row) + pub fn row + Copy>(&self, r: E) -> Result { + table_row_access!(r.into().0, self, make_edge_table_row) } } diff --git a/src/lib.rs b/src/lib.rs index 8f5414302..cb6b32cb7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -187,12 +187,22 @@ pub struct MutationId(tsk_id_t); #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)] pub struct MigrationId(tsk_id_t); +/// An edge ID +/// +/// This is an integer referring to a row of an [``EdgeTable``]. +/// +/// The features for this type follow the same pattern as for [``NodeId``] +#[repr(transparent)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)] +pub struct EdgeId(tsk_id_t); + impl_id_traits!(NodeId); impl_id_traits!(IndividualId); impl_id_traits!(PopulationId); impl_id_traits!(SiteId); impl_id_traits!(MutationId); impl_id_traits!(MigrationId); +impl_id_traits!(EdgeId); // tskit defines this via a type cast // in a macro. bindgen thus misses it. diff --git a/src/table_collection.rs b/src/table_collection.rs index 03997e80c..ec8564725 100644 --- a/src/table_collection.rs +++ b/src/table_collection.rs @@ -20,7 +20,7 @@ use crate::TreeSequenceFlags; use crate::TskReturnValue; use crate::TskitTypeAccess; use crate::{tsk_flags_t, tsk_id_t, tsk_size_t, TSK_NULL}; -use crate::{IndividualId, MutationId, NodeId, PopulationId, SiteId}; +use crate::{EdgeId, IndividualId, MutationId, NodeId, PopulationId, SiteId}; use ll_bindings::tsk_table_collection_free; /// A table collection. @@ -172,7 +172,7 @@ impl TableCollection { right: f64, parent: P, child: C, - ) -> TskReturnValue { + ) -> Result { self.add_edge_with_metadata(left, right, parent, child, None) } @@ -184,7 +184,7 @@ impl TableCollection { parent: P, child: C, metadata: Option<&dyn MetadataRoundtrip>, - ) -> TskReturnValue { + ) -> Result { let md = EncodedMetadata::new(metadata)?; let rv = unsafe { ll_bindings::tsk_edge_table_add_row( @@ -198,7 +198,7 @@ impl TableCollection { ) }; - handle_tsk_return_value!(rv) + handle_tsk_return_value!(rv, EdgeId::from(rv)) } /// Add a row to the individual table @@ -440,11 +440,11 @@ impl TableCollection { /// If `self.is_indexed()` is `true`, return a non-owning /// slice containing the edge insertion order. /// Otherwise, return `None`. - pub fn edge_insertion_order(&self) -> Option<&[tsk_id_t]> { + pub fn edge_insertion_order(&self) -> Option<&[EdgeId]> { if self.is_indexed() { Some(unsafe { std::slice::from_raw_parts( - (*self.as_ptr()).indexes.edge_insertion_order, + (*self.as_ptr()).indexes.edge_insertion_order as *const EdgeId, (*self.as_ptr()).indexes.num_edges as usize, ) }) @@ -456,11 +456,11 @@ impl TableCollection { /// If `self.is_indexed()` is `true`, return a non-owning /// slice containing the edge removal order. /// Otherwise, return `None`. - pub fn edge_removal_order(&self) -> Option<&[tsk_id_t]> { + pub fn edge_removal_order(&self) -> Option<&[EdgeId]> { if self.is_indexed() { Some(unsafe { std::slice::from_raw_parts( - (*self.as_ptr()).indexes.edge_removal_order, + (*self.as_ptr()).indexes.edge_removal_order as *const EdgeId, (*self.as_ptr()).indexes.num_edges as usize, ) }) @@ -803,6 +803,41 @@ mod test { assert!(*i >= 0); assert!(*i < tables.edges().num_rows() as tsk_id_t); } + + // The "transparent" casts are such black magic that we + // should probably test against what C thinks is going on :) + let input = unsafe { + std::slice::from_raw_parts( + (*tables.as_ptr()).indexes.edge_insertion_order, + (*tables.as_ptr()).indexes.num_edges as usize, + ) + }; + + assert!(!input.is_empty()); + + let tables_input = tables.edge_insertion_order().unwrap(); + + assert_eq!(input.len(), tables_input.len()); + + for i in 0..input.len() { + assert_eq!(EdgeId::from(input[i]), tables_input[i]); + } + + let output = unsafe { + std::slice::from_raw_parts( + (*tables.as_ptr()).indexes.edge_removal_order, + (*tables.as_ptr()).indexes.num_edges as usize, + ) + }; + assert!(!output.is_empty()); + + let tables_output = tables.edge_removal_order().unwrap(); + + assert_eq!(output.len(), tables_output.len()); + + for i in 0..output.len() { + assert_eq!(EdgeId::from(output[i]), tables_output[i]); + } } #[test]