From 4ae66545938d132cf9fc10309abeb871f8513c86 Mon Sep 17 00:00:00 2001 From: Kyle Carow Date: Mon, 29 Jan 2024 14:19:42 -0700 Subject: [PATCH] reorganize init checks for RustCycle --- rust/fastsim-core/src/cycle.rs | 49 ++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/rust/fastsim-core/src/cycle.rs b/rust/fastsim-core/src/cycle.rs index eb9d1979..c1bbbd14 100644 --- a/rust/fastsim-core/src/cycle.rs +++ b/rust/fastsim-core/src/cycle.rs @@ -633,21 +633,7 @@ impl SerdeAPI for RustCycle { const ACCEPTED_STR_FORMATS: &'static [&'static str] = &["yaml", "json", "csv"]; fn init(&mut self) -> anyhow::Result<()> { - ensure!(!self.is_empty(), "Deserialized cycle is empty"); - let cyc_len = self.len(); - ensure!( - self.mps.len() == cyc_len, - "Length of `mps` does not match length of `time_s`" - ); - ensure!( - self.grade.len() == cyc_len, - "Length of `grade` does not match length of `time_s`" - ); - ensure!( - self.road_type.len() == cyc_len, - "Length of `road_type` does not match length of `time_s`" - ); - Ok(()) + self.init_checks() } fn to_file>(&self, filepath: P) -> anyhow::Result<()> { @@ -738,7 +724,7 @@ impl TryFrom>> for RustCycle { let time_s = Array::from_vec( hashmap .get("time_s") - .with_context(|| "`time_s` not in HashMap")? + .with_context(|| format!("`time_s` not in HashMap: {hashmap:?}"))? .to_owned(), ); let cyc_len = time_s.len(); @@ -747,7 +733,7 @@ impl TryFrom>> for RustCycle { mps: Array::from_vec( hashmap .get("mps") - .with_context(|| "`mps` not in HashMap")? + .with_context(|| format!("`mps` not in HashMap: {hashmap:?}"))? .to_owned(), ), grade: Array::from_vec( @@ -783,6 +769,20 @@ impl From for HashMap> { /// pure Rust methods that need to be separate due to pymethods incompatibility impl RustCycle { + fn init_checks(&self) -> anyhow::Result<()> { + ensure!(!self.is_empty(), "Deserialized cycle is empty"); + ensure!(self.is_sorted(), "Deserialized cycle is not sorted in time"); + ensure!( + self.are_fields_equal_length(), + "Deserialized cycle has unequal field lengths\ntime_s: {}\nmps: {}\ngrade: {}\nroad_type: {}", + self.time_s.len(), + self.mps.len(), + self.grade.len(), + self.road_type.len(), + ); + Ok(()) + } + /// Load cycle from CSV file, parsing name from filepath pub fn from_csv_file>(filepath: P) -> anyhow::Result { let filepath = filepath.as_ref(); @@ -853,6 +853,21 @@ impl RustCycle { self.len() == 0 } + pub fn is_sorted(&self) -> bool { + self.time_s + .as_slice() + .unwrap() + .windows(2) + .all(|window| window[0] < window[1]) + } + + pub fn are_fields_equal_length(&self) -> bool { + let cyc_len = self.len(); + [self.mps.len(), self.grade.len(), self.road_type.len()] + .iter() + .all(|len| len == &cyc_len) + } + pub fn test_cyc() -> Self { Self { time_s: Array::range(0.0, 10.0, 1.0),