Skip to content

Commit

Permalink
Merge pull request #95 from NREL/f2/to_writer
Browse files Browse the repository at this point in the history
implement to_writer method and clean up RustCycle csv serde
  • Loading branch information
calbaker authored Feb 6, 2024
2 parents fba3d1d + 73d3285 commit 707d64e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 66 deletions.
44 changes: 22 additions & 22 deletions rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_resource")]
pub fn from_resource_py(filepath: &PyAny) -> anyhow::Result<Self> {
Self::from_resource(PathBuf::extract(filepath)?)
pub fn from_resource_py(filepath: &PyAny) -> PyResult<Self> {
Self::from_resource(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a file.
Expand All @@ -228,8 +228,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
/// * `filepath`: `str | pathlib.Path` - The filepath at which to write the object
///
#[pyo3(name = "to_file")]
pub fn to_file_py(&self, filepath: &PyAny) -> anyhow::Result<()> {
self.to_file(PathBuf::extract(filepath)?)
pub fn to_file_py(&self, filepath: &PyAny) -> PyResult<()> {
self.to_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a file.
Expand All @@ -241,8 +241,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_file")]
pub fn from_file_py(filepath: &PyAny) -> anyhow::Result<Self> {
Self::from_file(PathBuf::extract(filepath)?)
pub fn from_file_py(filepath: &PyAny) -> PyResult<Self> {
Self::from_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object into a string
Expand All @@ -252,8 +252,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
/// * `format`: `str` - The target format, any of those listed in [`ACCEPTED_STR_FORMATS`](`SerdeAPI::ACCEPTED_STR_FORMATS`)
///
#[pyo3(name = "to_str")]
pub fn to_str_py(&self, format: &str) -> anyhow::Result<String> {
self.to_str(format)
pub fn to_str_py(&self, format: &str) -> PyResult<String> {
self.to_str(format).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a string
Expand All @@ -265,14 +265,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_str")]
pub fn from_str_py(contents: &str, format: &str) -> anyhow::Result<Self> {
Self::from_str(contents, format)
pub fn from_str_py(contents: &str, format: &str) -> PyResult<Self> {
Self::from_str(contents, format).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a JSON string
#[pyo3(name = "to_json")]
pub fn to_json_py(&self) -> anyhow::Result<String> {
self.to_json()
pub fn to_json_py(&self) -> PyResult<String> {
self.to_json().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object to a JSON string
Expand All @@ -283,14 +283,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_json")]
pub fn from_json_py(json_str: &str) -> anyhow::Result<Self> {
Self::from_json(json_str)
pub fn from_json_py(json_str: &str) -> PyResult<Self> {
Self::from_json(json_str).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a YAML string
#[pyo3(name = "to_yaml")]
pub fn to_yaml_py(&self) -> anyhow::Result<String> {
self.to_yaml()
pub fn to_yaml_py(&self) -> PyResult<String> {
self.to_yaml().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a YAML string
Expand All @@ -301,14 +301,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_yaml")]
pub fn from_yaml_py(yaml_str: &str) -> anyhow::Result<Self> {
Self::from_yaml(yaml_str)
pub fn from_yaml_py(yaml_str: &str) -> PyResult<Self> {
Self::from_yaml(yaml_str).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to bincode-encoded `bytes`
#[pyo3(name = "to_bincode")]
pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> anyhow::Result<&'py PyBytes> {
Ok(PyBytes::new(py, &self.to_bincode()?))
pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> PyResult<&'py PyBytes> {
PyResult::Ok(PyBytes::new(py, &self.to_bincode()?)).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from bincode-encoded `bytes`
Expand All @@ -319,8 +319,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_bincode")]
pub fn from_bincode_py(encoded: &PyBytes) -> anyhow::Result<Self> {
Self::from_bincode(encoded.as_bytes())
pub fn from_bincode_py(encoded: &PyBytes) -> PyResult<Self> {
Self::from_bincode(encoded.as_bytes()).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}
});

Expand Down
71 changes: 32 additions & 39 deletions rust/fastsim-core/src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ impl RustCycleCache {
Ok(dict)
}
#[pyo3(name = "to_csv")]
pub fn to_csv_py(&self) -> PyResult<String> {
self.to_csv().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}
#[pyo3(name = "modify_by_const_jerk_trajectory")]
pub fn modify_by_const_jerk_trajectory_py(
&mut self,
Expand Down Expand Up @@ -636,19 +641,25 @@ impl SerdeAPI for RustCycle {
self.init_checks()
}

fn to_file<P: AsRef<Path>>(&self, filepath: P) -> anyhow::Result<()> {
let filepath = filepath.as_ref();
let extension = filepath
.extension()
.and_then(OsStr::to_str)
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?;
match extension.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?,
"json" => serde_json::to_writer(&File::create(filepath)?, self)?,
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
"csv" => self.write_csv(&mut csv::Writer::from_path(filepath)?)?,
fn to_writer<W: std::io::Write>(&self, wtr: W, format: &str) -> anyhow::Result<()> {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(wtr, self)?,
"json" => serde_json::to_writer(wtr, self)?,
"bin" => bincode::serialize_into(wtr, self)?,
"csv" => {
let mut wtr = csv::Writer::from_writer(wtr);
for i in 0..self.len() {
wtr.serialize(RustCycleElement {
time_s: self.time_s[i],
mps: self.mps[i],
grade: Some(self.grade[i]),
road_type: Some(self.road_type[i]),
})?;
}
wtr.flush()?
}
_ => bail!(
"Unsupported format {extension:?}, must be one of {:?}",
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Expand All @@ -660,11 +671,7 @@ impl SerdeAPI for RustCycle {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => self.to_yaml()?,
"json" => self.to_json()?,
"csv" => {
let mut wtr = csv::Writer::from_writer(Vec::with_capacity(self.len()));
self.write_csv(&mut wtr)?;
String::from_utf8(wtr.into_inner()?)?
}
"csv" => self.to_csv()?,
_ => {
bail!(
"Unsupported format {format:?}, must be one of {:?}",
Expand All @@ -682,7 +689,7 @@ impl SerdeAPI for RustCycle {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => Self::from_yaml(contents)?,
"json" => Self::from_json(contents)?,
"csv" => Self::from_csv_str(contents, "".to_string())?,
"csv" => Self::from_reader(contents.as_ref().as_bytes(), "csv")?,
_ => bail!(
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
Expand Down Expand Up @@ -791,37 +798,23 @@ impl RustCycle {
.and_then(OsStr::to_str)
.with_context(|| format!("Could not parse cycle name from filepath: {filepath:?}"))?
.to_string();
let file = File::open(filepath).with_context(|| {
if !filepath.exists() {
format!("File not found: {filepath:?}")
} else {
format!("Could not open file: {filepath:?}")
}
})?;
let mut cyc = Self::from_reader(file, "csv")?;
let mut cyc = Self::from_file(filepath)?;
cyc.name = name;
Ok(cyc)
}

/// Load cycle from CSV string
pub fn from_csv_str<S: AsRef<str>>(csv_str: S, name: String) -> anyhow::Result<Self> {
let mut cyc = Self::from_reader(csv_str.as_ref().as_bytes(), "csv")?;
let mut cyc = Self::from_str(csv_str, "csv")?;
cyc.name = name;
Ok(cyc)
}

/// Write cycle data to a CSV writer
fn write_csv<W: std::io::Write>(&self, wtr: &mut csv::Writer<W>) -> anyhow::Result<()> {
for i in 0..self.len() {
wtr.serialize(RustCycleElement {
time_s: self.time_s[i],
mps: self.mps[i],
grade: Some(self.grade[i]),
road_type: Some(self.road_type[i]),
})?;
}
wtr.flush()?;
Ok(())
/// Write (serialize) cycle to a CSV string
pub fn to_csv(&self) -> anyhow::Result<String> {
let mut buf = Vec::with_capacity(self.len());
self.to_writer(&mut buf, "csv")?;
Ok(String::from_utf8(buf)?)
}

pub fn build_cache(&self) -> RustCycleCache {
Expand Down
14 changes: 9 additions & 5 deletions rust/fastsim-core/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
.extension()
.and_then(OsStr::to_str)
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?;
match extension.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?,
"json" => serde_json::to_writer(&File::create(filepath)?, self)?,
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
self.to_writer(File::create(filepath)?, extension)
}

fn to_writer<W: std::io::Write>(&self, wtr: W, format: &str) -> anyhow::Result<()> {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(wtr, self)?,
"json" => serde_json::to_writer(wtr, self)?,
"bin" => bincode::serialize_into(wtr, self)?,
_ => bail!(
"Unsupported format {extension:?}, must be one of {:?}",
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Expand Down

0 comments on commit 707d64e

Please sign in to comment.