diff --git a/rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs b/rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs index 41a66c25..5029ad53 100644 --- a/rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs +++ b/rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs @@ -215,8 +215,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_resource")] - pub fn from_resource_py(filepath: &PyAny) -> anyhow::Result { - Self::from_resource(PathBuf::extract(filepath)?) + pub fn from_resource_py(filepath: &PyAny) -> PyResult { + Self::from_resource(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Write (serialize) an object to a file. @@ -228,8 +228,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// * `filepath`: `str | pathlib.Path` - The filepath at which to write the object /// #[pyo3(name = "to_file")] - pub fn to_file_py(&self, filepath: &PyAny) -> anyhow::Result<()> { - self.to_file(PathBuf::extract(filepath)?) + pub fn to_file_py(&self, filepath: &PyAny) -> PyResult<()> { + self.to_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Read (deserialize) an object from a file. @@ -241,8 +241,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_file")] - pub fn from_file_py(filepath: &PyAny) -> anyhow::Result { - Self::from_file(PathBuf::extract(filepath)?) + pub fn from_file_py(filepath: &PyAny) -> PyResult { + Self::from_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Write (serialize) an object into a string @@ -252,8 +252,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// * `format`: `str` - The target format, any of those listed in [`ACCEPTED_STR_FORMATS`](`SerdeAPI::ACCEPTED_STR_FORMATS`) /// #[pyo3(name = "to_str")] - pub fn to_str_py(&self, format: &str) -> anyhow::Result { - self.to_str(format) + pub fn to_str_py(&self, format: &str) -> PyResult { + self.to_str(format).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Read (deserialize) an object from a string @@ -265,14 +265,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_str")] - pub fn from_str_py(contents: &str, format: &str) -> anyhow::Result { - Self::from_str(contents, format) + pub fn from_str_py(contents: &str, format: &str) -> PyResult { + Self::from_str(contents, format).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Write (serialize) an object to a JSON string #[pyo3(name = "to_json")] - pub fn to_json_py(&self) -> anyhow::Result { - self.to_json() + pub fn to_json_py(&self) -> PyResult { + self.to_json().map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Read (deserialize) an object to a JSON string @@ -283,14 +283,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_json")] - pub fn from_json_py(json_str: &str) -> anyhow::Result { - Self::from_json(json_str) + pub fn from_json_py(json_str: &str) -> PyResult { + Self::from_json(json_str).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Write (serialize) an object to a YAML string #[pyo3(name = "to_yaml")] - pub fn to_yaml_py(&self) -> anyhow::Result { - self.to_yaml() + pub fn to_yaml_py(&self) -> PyResult { + self.to_yaml().map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Read (deserialize) an object from a YAML string @@ -301,14 +301,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_yaml")] - pub fn from_yaml_py(yaml_str: &str) -> anyhow::Result { - Self::from_yaml(yaml_str) + pub fn from_yaml_py(yaml_str: &str) -> PyResult { + Self::from_yaml(yaml_str).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Write (serialize) an object to bincode-encoded `bytes` #[pyo3(name = "to_bincode")] - pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> anyhow::Result<&'py PyBytes> { - Ok(PyBytes::new(py, &self.to_bincode()?)) + pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> PyResult<&'py PyBytes> { + PyResult::Ok(PyBytes::new(py, &self.to_bincode()?)).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } /// Read (deserialize) an object from bincode-encoded `bytes` @@ -319,8 +319,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream { /// #[staticmethod] #[pyo3(name = "from_bincode")] - pub fn from_bincode_py(encoded: &PyBytes) -> anyhow::Result { - Self::from_bincode(encoded.as_bytes()) + pub fn from_bincode_py(encoded: &PyBytes) -> PyResult { + Self::from_bincode(encoded.as_bytes()).map_err(|e| PyIOError::new_err(format!("{:?}", e))) } }); diff --git a/rust/fastsim-core/src/cycle.rs b/rust/fastsim-core/src/cycle.rs index eb9d1979..222065df 100644 --- a/rust/fastsim-core/src/cycle.rs +++ b/rust/fastsim-core/src/cycle.rs @@ -529,6 +529,11 @@ impl RustCycleCache { Ok(dict) } + #[pyo3(name = "to_csv")] + pub fn to_csv_py(&self) -> PyResult { + self.to_csv().map_err(|e| PyIOError::new_err(format!("{:?}", e))) + } + #[pyo3(name = "modify_by_const_jerk_trajectory")] pub fn modify_by_const_jerk_trajectory_py( &mut self, @@ -650,19 +655,25 @@ impl SerdeAPI for RustCycle { Ok(()) } - fn to_file>(&self, filepath: P) -> anyhow::Result<()> { - let filepath = filepath.as_ref(); - let extension = filepath - .extension() - .and_then(OsStr::to_str) - .with_context(|| format!("File extension could not be parsed: {filepath:?}"))?; - match extension.trim_start_matches('.').to_lowercase().as_str() { - "yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?, - "json" => serde_json::to_writer(&File::create(filepath)?, self)?, - "bin" => bincode::serialize_into(&File::create(filepath)?, self)?, - "csv" => self.write_csv(&mut csv::Writer::from_path(filepath)?)?, + fn to_writer(&self, wtr: W, format: &str) -> anyhow::Result<()> { + match format.trim_start_matches('.').to_lowercase().as_str() { + "yaml" | "yml" => serde_yaml::to_writer(wtr, self)?, + "json" => serde_json::to_writer(wtr, self)?, + "bin" => bincode::serialize_into(wtr, self)?, + "csv" => { + let mut wtr = csv::Writer::from_writer(wtr); + for i in 0..self.len() { + wtr.serialize(RustCycleElement { + time_s: self.time_s[i], + mps: self.mps[i], + grade: Some(self.grade[i]), + road_type: Some(self.road_type[i]), + })?; + } + wtr.flush()? + } _ => bail!( - "Unsupported format {extension:?}, must be one of {:?}", + "Unsupported format {format:?}, must be one of {:?}", Self::ACCEPTED_BYTE_FORMATS ), } @@ -674,11 +685,7 @@ impl SerdeAPI for RustCycle { match format.trim_start_matches('.').to_lowercase().as_str() { "yaml" | "yml" => self.to_yaml()?, "json" => self.to_json()?, - "csv" => { - let mut wtr = csv::Writer::from_writer(Vec::with_capacity(self.len())); - self.write_csv(&mut wtr)?; - String::from_utf8(wtr.into_inner()?)? - } + "csv" => self.to_csv()?, _ => { bail!( "Unsupported format {format:?}, must be one of {:?}", @@ -696,7 +703,7 @@ impl SerdeAPI for RustCycle { match format.trim_start_matches('.').to_lowercase().as_str() { "yaml" | "yml" => Self::from_yaml(contents)?, "json" => Self::from_json(contents)?, - "csv" => Self::from_csv_str(contents, "".to_string())?, + "csv" => Self::from_reader(contents.as_ref().as_bytes(), "csv")?, _ => bail!( "Unsupported format {format:?}, must be one of {:?}", Self::ACCEPTED_STR_FORMATS @@ -791,37 +798,23 @@ impl RustCycle { .and_then(OsStr::to_str) .with_context(|| format!("Could not parse cycle name from filepath: {filepath:?}"))? .to_string(); - let file = File::open(filepath).with_context(|| { - if !filepath.exists() { - format!("File not found: {filepath:?}") - } else { - format!("Could not open file: {filepath:?}") - } - })?; - let mut cyc = Self::from_reader(file, "csv")?; + let mut cyc = Self::from_file(filepath)?; cyc.name = name; Ok(cyc) } /// Load cycle from CSV string pub fn from_csv_str>(csv_str: S, name: String) -> anyhow::Result { - let mut cyc = Self::from_reader(csv_str.as_ref().as_bytes(), "csv")?; + let mut cyc = Self::from_str(csv_str, "csv")?; cyc.name = name; Ok(cyc) } - /// Write cycle data to a CSV writer - fn write_csv(&self, wtr: &mut csv::Writer) -> anyhow::Result<()> { - for i in 0..self.len() { - wtr.serialize(RustCycleElement { - time_s: self.time_s[i], - mps: self.mps[i], - grade: Some(self.grade[i]), - road_type: Some(self.road_type[i]), - })?; - } - wtr.flush()?; - Ok(()) + /// Write (serialize) cycle to a CSV string + pub fn to_csv(&self) -> anyhow::Result { + let mut buf = Vec::with_capacity(self.len()); + self.to_writer(&mut buf, "csv")?; + Ok(String::from_utf8(buf)?) } pub fn build_cache(&self) -> RustCycleCache { diff --git a/rust/fastsim-core/src/traits.rs b/rust/fastsim-core/src/traits.rs index e39c1678..0784c36b 100644 --- a/rust/fastsim-core/src/traits.rs +++ b/rust/fastsim-core/src/traits.rs @@ -42,12 +42,16 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> { .extension() .and_then(OsStr::to_str) .with_context(|| format!("File extension could not be parsed: {filepath:?}"))?; - match extension.trim_start_matches('.').to_lowercase().as_str() { - "yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?, - "json" => serde_json::to_writer(&File::create(filepath)?, self)?, - "bin" => bincode::serialize_into(&File::create(filepath)?, self)?, + self.to_writer(File::create(filepath)?, extension) + } + + fn to_writer(&self, wtr: W, format: &str) -> anyhow::Result<()> { + match format.trim_start_matches('.').to_lowercase().as_str() { + "yaml" | "yml" => serde_yaml::to_writer(wtr, self)?, + "json" => serde_json::to_writer(wtr, self)?, + "bin" => bincode::serialize_into(wtr, self)?, _ => bail!( - "Unsupported format {extension:?}, must be one of {:?}", + "Unsupported format {format:?}, must be one of {:?}", Self::ACCEPTED_BYTE_FORMATS ), }