Skip to content

Commit

Permalink
Add MigrationId and ProvenanceId (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen authored Jul 21, 2021
1 parent 7713ef6 commit 9baff07
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 36 deletions.
23 changes: 23 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,21 @@ pub struct SiteId(tsk_id_t);
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct MutationId(tsk_id_t);

/// A migration ID
///
/// This is an integer referring to a row of an [``MigrationTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct MigrationId(tsk_id_t);

impl_id_traits!(NodeId);
impl_id_traits!(IndividualId);
impl_id_traits!(PopulationId);
impl_id_traits!(SiteId);
impl_id_traits!(MutationId);
impl_id_traits!(MigrationId);

// tskit defines this via a type cast
// in a macro. bindgen thus misses it.
Expand Down Expand Up @@ -210,6 +220,19 @@ pub use trees::{NodeTraversalOrder, Tree, TreeSequence};
#[cfg(any(doc, feature = "provenance"))]
pub mod provenance;

/// A provenance ID
///
/// This is an integer referring to a row of an [``ProvenanceTable``].
///
/// The features for this type follow the same pattern as for [``NodeId``]
#[cfg(any(doc, feature = "provenance"))]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct ProvenanceId(tsk_id_t);

#[cfg(any(doc, feature = "provenance"))]
impl_id_traits!(ProvenanceId);

/// Handles return codes from low-level tskit functions.
///
/// When an error from the tskit C API is detected,
Expand Down
58 changes: 37 additions & 21 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::{tsk_id_t, TskitError};
use crate::{MigrationId, NodeId, PopulationId};

/// Row of a [`MigrationTable`]
pub struct MigrationTableRow {
pub id: tsk_id_t,
pub id: MigrationId,
pub left: f64,
pub right: f64,
pub node: tsk_id_t,
pub source: tsk_id_t,
pub dest: tsk_id_t,
pub node: NodeId,
pub source: PopulationId,
pub dest: PopulationId,
pub time: f64,
pub metadata: Option<Vec<u8>>,
}
Expand All @@ -30,7 +31,7 @@ impl PartialEq for MigrationTableRow {
fn make_migration_table_row(table: &MigrationTable, pos: tsk_id_t) -> Option<MigrationTableRow> {
if pos < table.num_rows() as tsk_id_t {
Some(MigrationTableRow {
id: pos,
id: pos.into(),
left: table.left(pos).unwrap(),
right: table.right(pos).unwrap(),
node: table.node(pos).unwrap(),
Expand Down Expand Up @@ -92,53 +93,68 @@ impl<'a> MigrationTable<'a> {
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn left(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.left)
pub fn left<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left)
}

/// Return the right coordinate for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn right(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.right)
pub fn right<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.right)
}

/// Return the node for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn node(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.source)
pub fn node<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.source, NodeId)
}

/// Return the source population for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn source(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.node)
pub fn source<M: Into<MigrationId> + Copy>(
&'a self,
row: M,
) -> Result<PopulationId, TskitError> {
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_.node,
PopulationId
)
}

/// Return the destination population for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn dest(&'a self, row: tsk_id_t) -> Result<tsk_id_t, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.dest)
pub fn dest<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<PopulationId, TskitError> {
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_.dest,
PopulationId
)
}

/// Return the time of the migration event for a given row.
///
/// # Errors
///
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn time(&'a self, row: tsk_id_t) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row, 0, self.num_rows(), self.table_.time)
pub fn time<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Result<f64, TskitError> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time)
}

/// Return the metadata for a given row.
Expand All @@ -148,9 +164,9 @@ impl<'a> MigrationTable<'a> {
/// * [`TskitError::IndexError`] if `row` is out of range.
pub fn metadata<T: metadata::MetadataRoundtrip>(
&'a self,
row: tsk_id_t,
row: MigrationId,
) -> Result<Option<T>, TskitError> {
let buffer = metadata_to_vector!(self, row)?;
let buffer = metadata_to_vector!(self, row.0)?;
decode_metadata_row!(T, buffer)
}

Expand All @@ -169,7 +185,7 @@ impl<'a> MigrationTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&self, r: tsk_id_t) -> Result<MigrationTableRow, TskitError> {
table_row_access!(r, self, make_migration_table_row)
pub fn row<M: Into<MigrationId> + Copy>(&self, r: M) -> Result<MigrationTableRow, TskitError> {
table_row_access!(r.into().0, self, make_migration_table_row)
}
}
31 changes: 20 additions & 11 deletions src/provenance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//! See [`Provenance`] for examples.
use crate::bindings as ll_bindings;
use crate::{tsk_id_t, tsk_size_t, TskitError};
use crate::{tsk_id_t, tsk_size_t, ProvenanceId, TskitError};

/// Enable provenance table access.
///
Expand Down Expand Up @@ -100,7 +100,7 @@ pub trait Provenance: crate::TableAccess {
/// # Parameters
///
/// * `record`: the provenance record
fn add_provenance(&mut self, record: &str) -> crate::TskReturnValue;
fn add_provenance(&mut self, record: &str) -> Result<ProvenanceId, TskitError>;
/// Return an immutable reference to the table, type [`ProvenanceTable`]
fn provenances(&self) -> ProvenanceTable;
/// Return an iterator over the rows of the [`ProvenanceTable`].
Expand All @@ -114,7 +114,7 @@ pub trait Provenance: crate::TableAccess {
/// Row of a [`ProvenanceTable`].
pub struct ProvenanceTableRow {
/// The row id
pub id: tsk_id_t,
pub id: ProvenanceId,
/// ISO-formatted time stamp
pub timestamp: String,
/// The provenance record
Expand All @@ -127,6 +127,12 @@ impl PartialEq for ProvenanceTableRow {
}
}

impl std::fmt::Display for ProvenanceId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "ProvenanceId({})", self.0)
}
}

impl std::fmt::Display for ProvenanceTableRow {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
Expand All @@ -140,7 +146,7 @@ impl std::fmt::Display for ProvenanceTableRow {
fn make_provenance_table_row(table: &ProvenanceTable, pos: tsk_id_t) -> Option<ProvenanceTableRow> {
if pos < table.num_rows() as tsk_id_t {
Some(ProvenanceTableRow {
id: pos,
id: pos.into(),
timestamp: table.timestamp(pos).unwrap(),
record: table.record(pos).unwrap(),
})
Expand Down Expand Up @@ -203,9 +209,9 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn timestamp(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
pub fn timestamp<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
match unsafe_tsk_ragged_char_column_access!(
row,
row.into().0,
0,
self.num_rows(),
self.table_.timestamp,
Expand All @@ -226,9 +232,9 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn record(&'a self, row: tsk_id_t) -> Result<String, TskitError> {
pub fn record<P: Into<ProvenanceId> + Copy>(&'a self, row: P) -> Result<String, TskitError> {
match unsafe_tsk_ragged_char_column_access!(
row,
row.into().0,
0,
self.num_rows(),
self.table_.record,
Expand All @@ -249,11 +255,14 @@ impl<'a> ProvenanceTable<'a> {
/// # Errors
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row(&'a self, row: tsk_id_t) -> Result<ProvenanceTableRow, TskitError> {
if row < 0 {
pub fn row<P: Into<ProvenanceId> + Copy>(
&'a self,
row: P,
) -> Result<ProvenanceTableRow, TskitError> {
if row.into() < 0 {
Err(TskitError::IndexError)
} else {
match make_provenance_table_row(self, row) {
match make_provenance_table_row(self, row.into().0) {
Some(x) => Ok(x),
None => Err(TskitError::IndexError),
}
Expand Down
4 changes: 2 additions & 2 deletions src/table_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ impl crate::traits::NodeListGenerator for TableCollection {}

#[cfg(any(doc, feature = "provenance"))]
impl crate::provenance::Provenance for TableCollection {
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
if record.is_empty() {
return Err(TskitError::ValueError {
got: String::from("empty string slice"),
Expand All @@ -650,7 +650,7 @@ impl crate::provenance::Provenance for TableCollection {
record.len() as tsk_size_t,
)
};
handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
}

fn provenances(&self) -> crate::provenance::ProvenanceTable {
Expand Down
4 changes: 2 additions & 2 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ impl crate::traits::NodeListGenerator for TreeSequence {}

#[cfg(any(doc, feature = "provenance"))]
impl crate::provenance::Provenance for TreeSequence {
fn add_provenance(&mut self, record: &str) -> TskReturnValue {
fn add_provenance(&mut self, record: &str) -> Result<crate::ProvenanceId, TskitError> {
if record.is_empty() {
return Err(TskitError::ValueError {
got: String::from("empty string slice"),
Expand All @@ -1152,7 +1152,7 @@ impl crate::provenance::Provenance for TreeSequence {
record.len() as tsk_size_t,
)
};
handle_tsk_return_value!(rv)
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
}

fn provenances(&self) -> crate::provenance::ProvenanceTable {
Expand Down

0 comments on commit 9baff07

Please sign in to comment.