Skip to content

Commit

Permalink
feat: FromStr derive could support setting the error type (#380)
Browse files Browse the repository at this point in the history
* feat: FromStr derive could support setting the error type

ref #91

* chore: adjust unit test

---------

Co-authored-by: Peter Glotfelty <glotfelty.2@osu.edu>
  • Loading branch information
JimChenWYU and Peternator7 authored Nov 25, 2024
1 parent 2fab0ba commit b03e1a0
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 10 deletions.
24 changes: 24 additions & 0 deletions strum_macros/src/helpers/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub mod kw {
custom_keyword!(serialize_all);
custom_keyword!(use_phf);
custom_keyword!(prefix);
custom_keyword!(parse_err_ty);
custom_keyword!(parse_err_fn);

// enum discriminant metadata
custom_keyword!(derive);
Expand Down Expand Up @@ -51,6 +53,14 @@ pub enum EnumMeta {
kw: kw::prefix,
prefix: LitStr,
},
ParseErrTy {
kw: kw::parse_err_ty,
path: Path,
},
ParseErrFn {
kw: kw::parse_err_fn,
path: Path,
},
}

impl Parse for EnumMeta {
Expand Down Expand Up @@ -80,6 +90,20 @@ impl Parse for EnumMeta {
input.parse::<Token![=]>()?;
let prefix = input.parse()?;
Ok(EnumMeta::Prefix { kw, prefix })
} else if lookahead.peek(kw::parse_err_ty) {
let kw = input.parse::<kw::parse_err_ty>()?;
input.parse::<Token![=]>()?;
let path_str: LitStr = input.parse()?;
let path_tokens = parse_str(&path_str.value())?;
let path = parse2(path_tokens)?;
Ok(EnumMeta::ParseErrTy { kw, path })
} else if lookahead.peek(kw::parse_err_fn) {
let kw = input.parse::<kw::parse_err_fn>()?;
input.parse::<Token![=]>()?;
let path_str: LitStr = input.parse()?;
let path_tokens = parse_str(&path_str.value())?;
let path = parse2(path_tokens)?;
Ok(EnumMeta::ParseErrFn { kw, path })
} else {
Err(lookahead.error())
}
Expand Down
7 changes: 7 additions & 0 deletions strum_macros/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ use proc_macro2::Span;
use quote::ToTokens;
use syn::spanned::Spanned;

pub fn missing_parse_err_attr_error() -> syn::Error {
syn::Error::new(
Span::call_site(),
"`parse_err_ty` and `parse_err_fn` attribute is both required.",
)
}

pub fn non_enum_error() -> syn::Error {
syn::Error::new(Span::call_site(), "This macro only supports enums.")
}
Expand Down
20 changes: 20 additions & 0 deletions strum_macros/src/helpers/type_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub trait HasTypeProperties {

#[derive(Clone, Default)]
pub struct StrumTypeProperties {
pub parse_err_ty: Option<Path>,
pub parse_err_fn: Option<Path>,
pub case_style: Option<CaseStyle>,
pub ascii_case_insensitive: bool,
pub crate_module_path: Option<Path>,
Expand All @@ -32,6 +34,8 @@ impl HasTypeProperties for DeriveInput {
let strum_meta = self.get_metadata()?;
let discriminants_meta = self.get_discriminants_metadata()?;

let mut parse_err_ty_kw = None;
let mut parse_err_fn_kw = None;
let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
let mut use_phf_kw = None;
Expand Down Expand Up @@ -82,6 +86,22 @@ impl HasTypeProperties for DeriveInput {
prefix_kw = Some(kw);
output.prefix = Some(prefix);
}
EnumMeta::ParseErrTy { path, kw } => {
if let Some(fst_kw) = parse_err_ty_kw {
return Err(occurrence_error(fst_kw, kw, "parse_err_ty"));
}

parse_err_ty_kw = Some(kw);
output.parse_err_ty = Some(path);
}
EnumMeta::ParseErrFn { path, kw } => {
if let Some(fst_kw) = parse_err_fn_kw {
return Err(occurrence_error(fst_kw, kw, "parse_err_fn"));
}

parse_err_fn_kw = Some(kw);
output.parse_err_fn = Some(path);
}
}
}

Expand Down
37 changes: 27 additions & 10 deletions strum_macros/src/macros/strings/from_string.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields};
use syn::{parse_quote, Data, DeriveInput, Fields, Path};

use crate::helpers::{
non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties,
HasTypeProperties,
missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
HasStrumVariantProperties, HasTypeProperties,
};

pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Expand All @@ -19,9 +19,25 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let strum_module_path = type_properties.crate_module_path();

let mut default_kw = None;
let mut default =
quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };

let (mut default_err_ty, mut default) = match (
type_properties.parse_err_ty,
type_properties.parse_err_fn,
) {
(None, None) => (
quote! { #strum_module_path::ParseError },
quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) },
),
(Some(ty), Some(f)) => {
let ty_path: Path = parse_quote!(#ty);
let fn_path: Path = parse_quote!(#f);

(
quote! { #ty_path },
quote! { ::core::result::Result::Err(#fn_path(s)) },
)
}
_ => return Err(missing_parse_err_attr_error()),
};
let mut phf_exact_match_arms = Vec::new();
let mut standard_match_arms = Vec::new();
for variant in variants {
Expand All @@ -47,6 +63,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}
default_kw = Some(kw);
default_err_ty = quote! { #strum_module_path::ParseError };
default = quote! {
::core::result::Result::Ok(#name::#ident(s.into()))
};
Expand Down Expand Up @@ -146,7 +163,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let from_str = quote! {
#[allow(clippy::use_self)]
impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
type Err = #strum_module_path::ParseError;
type Err = #default_err_ty;

#[inline]
fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
Expand All @@ -160,7 +177,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
&impl_generics,
&ty_generics,
where_clause,
&strum_module_path,
&default_err_ty,
);

Ok(quote! {
Expand All @@ -186,12 +203,12 @@ fn try_from_str(
impl_generics: &syn::ImplGenerics,
ty_generics: &syn::TypeGenerics,
where_clause: Option<&syn::WhereClause>,
strum_module_path: &syn::Path,
default_err_ty: &TokenStream,
) -> TokenStream {
quote! {
#[allow(clippy::use_self)]
impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
type Error = #strum_module_path::ParseError;
type Error = #default_err_ty;

#[inline]
fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
Expand Down
33 changes: 33 additions & 0 deletions strum_tests/tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,36 @@ fn color_default_with_white() {
}
}
}

#[derive(Debug, EnumString)]
#[strum(
parse_err_fn = "some_enum_not_found_err",
parse_err_ty = "CaseCustomParseErrorNotFoundError"
)]
enum CaseCustomParseErrorEnum {
#[strum(serialize = "red")]
Red,
#[strum(serialize = "blue")]
Blue,
}
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CaseCustomParseErrorNotFoundError(String);
impl std::fmt::Display for CaseCustomParseErrorNotFoundError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "not found `{}`", self.0)
}
}
impl std::error::Error for CaseCustomParseErrorNotFoundError {}
fn some_enum_not_found_err(s: &str) -> CaseCustomParseErrorNotFoundError {
CaseCustomParseErrorNotFoundError(s.to_string())
}

#[test]
fn case_custom_parse_error() {
let r = "yellow".parse::<CaseCustomParseErrorEnum>();
assert!(r.is_err());
assert_eq!(
CaseCustomParseErrorNotFoundError("yellow".to_string()),
r.unwrap_err()
);
}

0 comments on commit b03e1a0

Please sign in to comment.