From 9baff077b5f6dd247b6e8c01905a391ad23e5650 Mon Sep 17 00:00:00 2001 From: "Kevin R. Thornton" Date: Wed, 21 Jul 2021 10:48:52 -0700 Subject: [PATCH] Add MigrationId and ProvenanceId (#137) --- src/lib.rs | 23 ++++++++++++++++ src/migration_table.rs | 58 ++++++++++++++++++++++++++--------------- src/provenance.rs | 31 ++++++++++++++-------- src/table_collection.rs | 4 +-- src/trees.rs | 4 +-- 5 files changed, 84 insertions(+), 36 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4028d81e1..8f5414302 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -178,11 +178,21 @@ pub struct SiteId(tsk_id_t); #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)] pub struct MutationId(tsk_id_t); +/// A migration ID +/// +/// This is an integer referring to a row of an [``MigrationTable``]. +/// +/// The features for this type follow the same pattern as for [``NodeId``] +#[repr(transparent)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)] +pub struct MigrationId(tsk_id_t); + impl_id_traits!(NodeId); impl_id_traits!(IndividualId); impl_id_traits!(PopulationId); impl_id_traits!(SiteId); impl_id_traits!(MutationId); +impl_id_traits!(MigrationId); // tskit defines this via a type cast // in a macro. bindgen thus misses it. @@ -210,6 +220,19 @@ pub use trees::{NodeTraversalOrder, Tree, TreeSequence}; #[cfg(any(doc, feature = "provenance"))] pub mod provenance; +/// A provenance ID +/// +/// This is an integer referring to a row of an [``ProvenanceTable``]. +/// +/// The features for this type follow the same pattern as for [``NodeId``] +#[cfg(any(doc, feature = "provenance"))] +#[repr(transparent)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)] +pub struct ProvenanceId(tsk_id_t); + +#[cfg(any(doc, feature = "provenance"))] +impl_id_traits!(ProvenanceId); + /// Handles return codes from low-level tskit functions. /// /// When an error from the tskit C API is detected, diff --git a/src/migration_table.rs b/src/migration_table.rs index 2f273b26b..a2a0874cf 100644 --- a/src/migration_table.rs +++ b/src/migration_table.rs @@ -1,15 +1,16 @@ use crate::bindings as ll_bindings; use crate::metadata; use crate::{tsk_id_t, TskitError}; +use crate::{MigrationId, NodeId, PopulationId}; /// Row of a [`MigrationTable`] pub struct MigrationTableRow { - pub id: tsk_id_t, + pub id: MigrationId, pub left: f64, pub right: f64, - pub node: tsk_id_t, - pub source: tsk_id_t, - pub dest: tsk_id_t, + pub node: NodeId, + pub source: PopulationId, + pub dest: PopulationId, pub time: f64, pub metadata: Option>, } @@ -30,7 +31,7 @@ impl PartialEq for MigrationTableRow { fn make_migration_table_row(table: &MigrationTable, pos: tsk_id_t) -> Option { if pos < table.num_rows() as tsk_id_t { Some(MigrationTableRow { - id: pos, + id: pos.into(), left: table.left(pos).unwrap(), right: table.right(pos).unwrap(), node: table.node(pos).unwrap(), @@ -92,8 +93,8 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn left(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.left) + pub fn left + Copy>(&'a self, row: M) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left) } /// Return the right coordinate for a given row. @@ -101,8 +102,8 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn right(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.right) + pub fn right + Copy>(&'a self, row: M) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.right) } /// Return the node for a given row. @@ -110,8 +111,8 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn node(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.source) + pub fn node + Copy>(&'a self, row: M) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.source, NodeId) } /// Return the source population for a given row. @@ -119,8 +120,17 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn source(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node) + pub fn source + Copy>( + &'a self, + row: M, + ) -> Result { + unsafe_tsk_column_access!( + row.into().0, + 0, + self.num_rows(), + self.table_.node, + PopulationId + ) } /// Return the destination population for a given row. @@ -128,8 +138,14 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn dest(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.dest) + pub fn dest + Copy>(&'a self, row: M) -> Result { + unsafe_tsk_column_access!( + row.into().0, + 0, + self.num_rows(), + self.table_.dest, + PopulationId + ) } /// Return the time of the migration event for a given row. @@ -137,8 +153,8 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// * [`TskitError::IndexError`] if `row` is out of range. - pub fn time(&'a self, row: tsk_id_t) -> Result { - unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time) + pub fn time + Copy>(&'a self, row: M) -> Result { + unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time) } /// Return the metadata for a given row. @@ -148,9 +164,9 @@ impl<'a> MigrationTable<'a> { /// * [`TskitError::IndexError`] if `row` is out of range. pub fn metadata( &'a self, - row: tsk_id_t, + row: MigrationId, ) -> Result, TskitError> { - let buffer = metadata_to_vector!(self, row)?; + let buffer = metadata_to_vector!(self, row.0)?; decode_metadata_row!(T, buffer) } @@ -169,7 +185,7 @@ impl<'a> MigrationTable<'a> { /// # Errors /// /// [`TskitError::IndexError`] if `r` is out of range. - pub fn row(&self, r: tsk_id_t) -> Result { - table_row_access!(r, self, make_migration_table_row) + pub fn row + Copy>(&self, r: M) -> Result { + table_row_access!(r.into().0, self, make_migration_table_row) } } diff --git a/src/provenance.rs b/src/provenance.rs index db07ecc1e..91e45f74e 100644 --- a/src/provenance.rs +++ b/src/provenance.rs @@ -11,7 +11,7 @@ //! See [`Provenance`] for examples. use crate::bindings as ll_bindings; -use crate::{tsk_id_t, tsk_size_t, TskitError}; +use crate::{tsk_id_t, tsk_size_t, ProvenanceId, TskitError}; /// Enable provenance table access. /// @@ -100,7 +100,7 @@ pub trait Provenance: crate::TableAccess { /// # Parameters /// /// * `record`: the provenance record - fn add_provenance(&mut self, record: &str) -> crate::TskReturnValue; + fn add_provenance(&mut self, record: &str) -> Result; /// Return an immutable reference to the table, type [`ProvenanceTable`] fn provenances(&self) -> ProvenanceTable; /// Return an iterator over the rows of the [`ProvenanceTable`]. @@ -114,7 +114,7 @@ pub trait Provenance: crate::TableAccess { /// Row of a [`ProvenanceTable`]. pub struct ProvenanceTableRow { /// The row id - pub id: tsk_id_t, + pub id: ProvenanceId, /// ISO-formatted time stamp pub timestamp: String, /// The provenance record @@ -127,6 +127,12 @@ impl PartialEq for ProvenanceTableRow { } } +impl std::fmt::Display for ProvenanceId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "ProvenanceId({})", self.0) + } +} + impl std::fmt::Display for ProvenanceTableRow { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!( @@ -140,7 +146,7 @@ impl std::fmt::Display for ProvenanceTableRow { fn make_provenance_table_row(table: &ProvenanceTable, pos: tsk_id_t) -> Option { if pos < table.num_rows() as tsk_id_t { Some(ProvenanceTableRow { - id: pos, + id: pos.into(), timestamp: table.timestamp(pos).unwrap(), record: table.record(pos).unwrap(), }) @@ -203,9 +209,9 @@ impl<'a> ProvenanceTable<'a> { /// # Errors /// /// [`TskitError::IndexError`] if `r` is out of range. - pub fn timestamp(&'a self, row: tsk_id_t) -> Result { + pub fn timestamp + Copy>(&'a self, row: P) -> Result { match unsafe_tsk_ragged_char_column_access!( - row, + row.into().0, 0, self.num_rows(), self.table_.timestamp, @@ -226,9 +232,9 @@ impl<'a> ProvenanceTable<'a> { /// # Errors /// /// [`TskitError::IndexError`] if `r` is out of range. - pub fn record(&'a self, row: tsk_id_t) -> Result { + pub fn record + Copy>(&'a self, row: P) -> Result { match unsafe_tsk_ragged_char_column_access!( - row, + row.into().0, 0, self.num_rows(), self.table_.record, @@ -249,11 +255,14 @@ impl<'a> ProvenanceTable<'a> { /// # Errors /// /// [`TskitError::IndexError`] if `r` is out of range. - pub fn row(&'a self, row: tsk_id_t) -> Result { - if row < 0 { + pub fn row + Copy>( + &'a self, + row: P, + ) -> Result { + if row.into() < 0 { Err(TskitError::IndexError) } else { - match make_provenance_table_row(self, row) { + match make_provenance_table_row(self, row.into().0) { Some(x) => Ok(x), None => Err(TskitError::IndexError), } diff --git a/src/table_collection.rs b/src/table_collection.rs index 6df1ebf19..03997e80c 100644 --- a/src/table_collection.rs +++ b/src/table_collection.rs @@ -633,7 +633,7 @@ impl crate::traits::NodeListGenerator for TableCollection {} #[cfg(any(doc, feature = "provenance"))] impl crate::provenance::Provenance for TableCollection { - fn add_provenance(&mut self, record: &str) -> TskReturnValue { + fn add_provenance(&mut self, record: &str) -> Result { if record.is_empty() { return Err(TskitError::ValueError { got: String::from("empty string slice"), @@ -650,7 +650,7 @@ impl crate::provenance::Provenance for TableCollection { record.len() as tsk_size_t, ) }; - handle_tsk_return_value!(rv) + handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv)) } fn provenances(&self) -> crate::provenance::ProvenanceTable { diff --git a/src/trees.rs b/src/trees.rs index b6c88a618..6f1cf169b 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -1135,7 +1135,7 @@ impl crate::traits::NodeListGenerator for TreeSequence {} #[cfg(any(doc, feature = "provenance"))] impl crate::provenance::Provenance for TreeSequence { - fn add_provenance(&mut self, record: &str) -> TskReturnValue { + fn add_provenance(&mut self, record: &str) -> Result { if record.is_empty() { return Err(TskitError::ValueError { got: String::from("empty string slice"), @@ -1152,7 +1152,7 @@ impl crate::provenance::Provenance for TreeSequence { record.len() as tsk_size_t, ) }; - handle_tsk_return_value!(rv) + handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv)) } fn provenances(&self) -> crate::provenance::ProvenanceTable {