From cd4d6f1a4b50d12dfb3bc12676d145df60b34578 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 18 Jan 2024 09:21:26 +0300 Subject: [PATCH] [refactor]: update wasm_codec_derive to use syn2 (#4188) Signed-off-by: VAmuzing --- Cargo.lock | 5 +- wasm_codec/derive/Cargo.toml | 5 +- wasm_codec/derive/src/lib.rs | 380 +++++++++++++++++++++-------------- 3 files changed, 239 insertions(+), 151 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5aecd17db3b..2add67a0266 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2762,11 +2762,12 @@ dependencies = [ name = "iroha_core_wasm_codec_derive" version = "2.0.0-pre-rc.20" dependencies = [ + "iroha_macro_utils", + "manyhow", "once_cell", - "proc-macro-error", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.41", ] [[package]] diff --git a/wasm_codec/derive/Cargo.toml b/wasm_codec/derive/Cargo.toml index 27119faf744..5dc2f1134d7 100644 --- a/wasm_codec/derive/Cargo.toml +++ b/wasm_codec/derive/Cargo.toml @@ -13,8 +13,9 @@ workspace = true proc-macro = true [dependencies] -syn = { workspace = true, features = ["default", "full", "extra-traits", "parsing"] } +syn2 = { workspace = true } quote = { workspace = true } proc-macro2 = { workspace = true } -proc-macro-error = { workspace = true } once_cell = { workspace = true } +manyhow = { workspace = true } +iroha_macro_utils = { workspace = true } diff --git a/wasm_codec/derive/src/lib.rs b/wasm_codec/derive/src/lib.rs index 697c0b906f3..f5886b98ea3 100644 --- a/wasm_codec/derive/src/lib.rs +++ b/wasm_codec/derive/src/lib.rs @@ -4,28 +4,28 @@ use std::ops::Deref; -use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; -use proc_macro_error::{abort, diagnostic, proc_macro_error, Diagnostic, Level, OptionExt as _}; +use iroha_macro_utils::Emitter; +use manyhow::{bail, emit, manyhow, Result}; +use proc_macro2::TokenStream; use quote::quote; -use syn::{parse_quote, punctuated::Punctuated}; +use syn2::{parse_quote, punctuated::Punctuated}; mod kw { - syn::custom_keyword!(state); + syn2::custom_keyword!(state); } struct StateAttr { _state: kw::state, - _equal: syn::Token![=], - ty: syn::Type, + _equal: syn2::Token![=], + ty: syn2::Type, } -impl syn::parse::Parse for StateAttr { - fn parse(input: syn::parse::ParseStream) -> syn::Result { +impl syn2::parse::Parse for StateAttr { + fn parse(input: syn2::parse::ParseStream) -> syn2::Result { let state = input.parse()?; let equal = input.parse()?; - let type_str: syn::LitStr = input.parse()?; - let ty = syn::parse_str(&type_str.value())?; + let type_str: syn2::LitStr = input.parse()?; + let ty = syn2::parse_str(&type_str.value())?; Ok(Self { _state: state, _equal: equal, @@ -56,83 +56,134 @@ impl syn::parse::Parse for StateAttr { /// /// You can pass an attribute in the form of `#[wrap(state = "YourStateType")]`. /// This is needed in cases when it's impossible to infer the state type from the function signature. -#[proc_macro_error] +#[manyhow] #[proc_macro_attribute] pub fn wrap(attr: TokenStream, item: TokenStream) -> TokenStream { - let state_attr_opt = if attr.is_empty() { + let mut emitter = Emitter::new(); + + let state_attr_opt: Option = if attr.is_empty() { None + } else if let Some(v) = emitter.handle(syn2::parse2(attr)) { + Some(v) } else { - Some(syn::parse_macro_input!(attr as StateAttr)) + return emitter.finish_token_stream(); + }; + + let Some(fn_item): Option = emitter.handle(syn2::parse2(item)) else { + return emitter.finish_token_stream(); }; - let mut fn_item = syn::parse_macro_input!(item as syn::ItemFn); + + let parsing_result = impl_wrap_fn(&mut emitter, &state_attr_opt, fn_item); + + if let Some(result) = parsing_result { + emitter.finish_token_stream_with(result) + } else { + emitter.finish_token_stream() + } +} + +fn impl_wrap_fn( + emitter: &mut Emitter, + state_attr_opt: &Option, + mut fn_item: syn2::ItemFn, +) -> Option { let ident = &fn_item.sig.ident; let mut inner_fn_item = fn_item.clone(); - let inner_fn_ident = syn::Ident::new(&format!("__{ident}_inner"), ident.span()); + let inner_fn_ident = syn2::Ident::new(&format!("__{ident}_inner"), ident.span()); inner_fn_item.sig.ident = inner_fn_ident.clone(); - let fn_class = classify_fn(&fn_item.sig); + let fn_class = classify_fn(emitter, &fn_item.sig)?; - fn_item.sig.inputs = gen_params( + let maybe_sig_inputs = gen_params( + emitter, &fn_class, state_attr_opt.as_ref().map(|state_attr| &state_attr.ty), true, ); - let output = gen_output(&fn_class); - fn_item.sig.output = parse_quote! {-> #output}; - - let body = gen_body( + let maybe_body = gen_body( + emitter, &inner_fn_ident, &fn_class, state_attr_opt.as_ref().map(|state_attr| &state_attr.ty), ); + + let (Some(sig_inputs), Some(body)) = (maybe_sig_inputs, maybe_body) else { + return None; + }; + + let output = gen_output(&fn_class); + fn_item.sig.output = parse_quote! {-> #output}; + + fn_item.sig.inputs = sig_inputs; fn_item.block = parse_quote!({#body}); - quote! { + Some(quote! { #inner_fn_item #fn_item - } - .into() + }) } /// Macro to wrap trait function signature with normal parameters and return value /// to another one which will meet `wasmtime` specifications. /// /// See [`wrap`] for more details. -#[proc_macro_error] +#[manyhow] #[proc_macro_attribute] pub fn wrap_trait_fn(attr: TokenStream, item: TokenStream) -> TokenStream { - let state_attr_opt = if attr.is_empty() { + let mut emitter = Emitter::new(); + + let state_attr_opt: Option = if attr.is_empty() { None + } else if let Some(v) = emitter.handle(syn2::parse2(attr)) { + Some(v) } else { - Some(syn::parse_macro_input!(attr as StateAttr)) + return emitter.finish_token_stream(); }; - let mut fn_item = syn::parse_macro_input!(item as syn::TraitItemMethod); + + let Some(fn_item): Option = emitter.handle(syn2::parse2(item)) else { + return emitter.finish_token_stream(); + }; + + let parsing_result = impl_wrap_trait_fn(&mut emitter, &state_attr_opt, fn_item); + + if let Some(result) = parsing_result { + emitter.finish_token_stream_with(result) + } else { + emitter.finish_token_stream() + } +} + +fn impl_wrap_trait_fn( + emitter: &mut Emitter, + state_attr_opt: &Option, + mut fn_item: syn2::TraitItemFn, +) -> Option { let ident = &fn_item.sig.ident; let mut inner_fn_item = fn_item.clone(); - let inner_fn_ident = syn::Ident::new(&format!("__{ident}_inner"), ident.span()); + let inner_fn_ident = syn2::Ident::new(&format!("__{ident}_inner"), ident.span()); inner_fn_item.sig.ident = inner_fn_ident; - let fn_class = classify_fn(&fn_item.sig); + let fn_class = classify_fn(emitter, &fn_item.sig)?; fn_item.sig.inputs = gen_params( + emitter, &fn_class, state_attr_opt.as_ref().map(|state_attr| &state_attr.ty), false, - ); + )?; let output = gen_output(&fn_class); fn_item.sig.output = parse_quote! {-> #output}; - quote! { + Some(quote! { #inner_fn_item #fn_item - } - .into() + }) } /// `with_body` parameter specifies if end function will have a body or not. @@ -140,17 +191,19 @@ pub fn wrap_trait_fn(attr: TokenStream, item: TokenStream) -> TokenStream { /// This is required because /// [patterns are not allowed in functions without body ](/~https://github.com/rust-lang/rust/issues/35203). fn gen_params( + emitter: &mut Emitter, FnClass { param, state: state_ty_from_fn_sig, return_type, }: &FnClass, - state_ty_from_attr: Option<&syn::Type>, + state_ty_from_attr: Option<&syn2::Type>, with_body: bool, -) -> Punctuated { +) -> Option> { let mut params = Punctuated::new(); if state_ty_from_fn_sig.is_some() || param.is_some() || return_type.is_some() { - let state_ty = retrieve_state_ty(state_ty_from_attr, state_ty_from_fn_sig.as_ref()); + let state_ty = + retrieve_state_ty(emitter, state_ty_from_attr, state_ty_from_fn_sig.as_ref())?; let mutability = if with_body { quote! {mut} } else { @@ -170,14 +223,14 @@ fn gen_params( }); } - params + Some(params) } fn gen_output( FnClass { param, return_type, .. }: &FnClass, -) -> syn::Type { +) -> syn2::Type { match (param, return_type) { (None, None) => parse_quote! { () }, (Some(_), None | Some(ReturnType::Result(None, ErrType::WasmtimeError))) => parse_quote! { @@ -189,33 +242,34 @@ fn gen_output( } } -/// [`TokenStream2`] wrapper which will be lazily evaluated +/// [`TokenStream`] wrapper which will be lazily evaluated /// /// Implements [`quote::ToTokens`] trait -struct LazyTokenStream(once_cell::unsync::Lazy); +struct LazyTokenStream(once_cell::unsync::Lazy); -impl TokenStream2> LazyTokenStream { +impl TokenStream> LazyTokenStream { pub fn new(f: F) -> Self { Self(once_cell::unsync::Lazy::new(f)) } } -impl TokenStream2> quote::ToTokens for LazyTokenStream { - fn to_tokens(&self, tokens: &mut TokenStream2) { +impl TokenStream> quote::ToTokens for LazyTokenStream { + fn to_tokens(&self, tokens: &mut TokenStream) { let inner = &*self.0; inner.to_tokens(tokens); } } fn gen_body( - inner_fn_ident: &syn::Ident, + emitter: &mut Emitter, + inner_fn_ident: &syn2::Ident, FnClass { param, state: state_ty_from_fn_sig, return_type, }: &FnClass, - state_ty_from_attr: Option<&syn::Type>, -) -> TokenStream2 { + state_ty_from_attr: Option<&syn2::Type>, +) -> Option { let decode_param = param.as_ref().map_or_else( || quote! {}, |param_ty| quote! { @@ -229,21 +283,23 @@ fn gen_body( quote! {} }; + let memory_state_ty = + retrieve_state_ty(emitter, state_ty_from_attr, state_ty_from_fn_sig.as_ref())?; let get_memory = LazyTokenStream::new(|| { - let state_ty = retrieve_state_ty(state_ty_from_attr, state_ty_from_fn_sig.as_ref()); quote! { - let memory = Runtime::<#state_ty>::get_memory(&mut caller).expect("Checked at instantiation step"); + let memory = Runtime::<#memory_state_ty>::get_memory(&mut caller).expect("Checked at instantiation step"); } }); + let alloc_state_ty = + retrieve_state_ty(emitter, state_ty_from_attr, state_ty_from_fn_sig.as_ref())?; let get_alloc = LazyTokenStream::new(|| { - let state_ty = retrieve_state_ty(state_ty_from_attr, state_ty_from_fn_sig.as_ref()); quote! { - let alloc_fn = Runtime::<#state_ty>::get_alloc_fn(&mut caller).expect("Checked at instantiation step"); + let alloc_fn = Runtime::<#alloc_state_ty>::get_alloc_fn(&mut caller).expect("Checked at instantiation step"); } }); - match (param, return_type) { + let output = match (param, return_type) { // foo() => // foo() // @@ -314,15 +370,17 @@ fn gen_body( ::iroha_wasm_codec::encode_into_memory(&value, &memory, &alloc_fn, &mut caller) } } - } + }; + + Some(output) } /// Classified function struct FnClass { /// Input parameter - param: Option, + param: Option, /// Does function require state explicitly? - state: Option, + state: Option, /// Return type. /// [`None`] means `()` return_type: Option, @@ -331,10 +389,10 @@ struct FnClass { /// Classified return type enum ReturnType { /// [`Result`] type with [`Ok`] and [`Err`] types respectively - Result(Option, ErrType), + Result(Option, ErrType), /// Something other than [`Result`] #[allow(dead_code)] // May be used in future - Other(syn::Type), + Other(syn2::Type), } /// Classified error type @@ -343,182 +401,210 @@ enum ErrType { WasmtimeError, /// Something other than `wasmtime::Error` #[allow(dead_code)] // May be used in future - Other(syn::Type), + Other(syn2::Type), } -fn classify_fn(fn_sig: &syn::Signature) -> FnClass { +fn classify_fn(emitter: &mut Emitter, fn_sig: &syn2::Signature) -> Option { let params = &fn_sig.inputs; - let (param, state) = classify_params_and_state(params); + + // It does not make sense to check further if the next function fails + let (param, state) = classify_params_and_state(emitter, params)?; let output = &fn_sig.output; let output_ty = match output { - syn::ReturnType::Default => { - return FnClass { + syn2::ReturnType::Default => { + return Some(FnClass { param, state, return_type: None, - } + }) } - syn::ReturnType::Type(_, ref ty) => ty, + syn2::ReturnType::Type(_, ref ty) => ty, }; - let output_type_path = unwrap_path(output_ty); - let output_last_segment = last_segment(output_type_path); + let output_type_path = unwrap_path(emitter, output_ty)?; + let output_last_segment = last_segment(emitter, output_type_path)?; if output_last_segment.ident != "Result" { - return FnClass { + return Some(FnClass { param, state, return_type: Some(ReturnType::Other(*output_ty.clone())), - }; + }); } - let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { + let syn2::PathArguments::AngleBracketed(syn2::AngleBracketedGenericArguments { args: generics, .. }) = &output_last_segment.arguments else { - abort!( + emit!( + emitter, output_last_segment.arguments, "`Result` return type should have generic arguments" ); + return None; }; - let ok_type = classify_ok_type(generics); - let err_type = extract_err_type(generics); + let maybe_ok_type = emitter.handle(classify_ok_type(generics)); + + let err_type = extract_err_type(emitter, generics)?; + let err_type_path = unwrap_path(emitter, err_type)?; + let maybe_err_type_last_segment = last_segment(emitter, err_type_path); + + let (Some(ok_type), Some(err_type_last_segment)) = (maybe_ok_type, maybe_err_type_last_segment) + else { + return None; + }; - let err_type_path = unwrap_path(err_type); - let err_type_last_segment = last_segment(err_type_path); let err_type = if err_type_last_segment.ident == "WasmtimeError" { ErrType::WasmtimeError } else { ErrType::Other(err_type.clone()) }; - FnClass { + Some(FnClass { param, state, return_type: Some(ReturnType::Result(ok_type, err_type)), - } + }) } -fn extract_type_from_fn_arg(fn_arg: syn::FnArg) -> syn::PatType { - let syn::FnArg::Typed(pat_type) = fn_arg else { - abort!(fn_arg, "`self` arguments are forbidden"); - }; - - pat_type +fn extract_type_from_fn_arg(emitter: &mut Emitter, fn_arg: syn2::FnArg) -> Option { + if let syn2::FnArg::Typed(pat_type) = fn_arg { + Some(pat_type) + } else { + emit!(emitter, fn_arg, "`self` arguments are forbidden"); + None + } } fn classify_params_and_state( - params: &Punctuated, -) -> (Option, Option) { + emitter: &mut Emitter, + params: &Punctuated, +) -> Option<(Option, Option)> { match params.len() { - 0 => (None, None), + 0 => Some((None, None)), 1 => { let mut params_iter = params.iter(); - let first_param = extract_type_from_fn_arg(params_iter.next().unwrap().clone()); + let first_param = + extract_type_from_fn_arg(emitter, params_iter.next().unwrap().clone())?; if let Ok(state_ty) = parse_state_param(&first_param) { - (None, Some(state_ty.clone())) + Some((None, Some(state_ty.clone()))) } else { - (Some(first_param.ty.deref().clone()), None) + Some((Some(first_param.ty.deref().clone()), None)) } } 2 => { let mut params_iter = params.iter(); - let first_param = extract_type_from_fn_arg(params_iter.next().unwrap().clone()); + let maybe_first_param = + extract_type_from_fn_arg(emitter, params_iter.next().unwrap().clone()); - let second_param = extract_type_from_fn_arg(params_iter.next().unwrap().clone()); - match parse_state_param(&second_param) { - Ok(state_ty) => (Some(first_param.ty.deref().clone()), Some(state_ty.clone())), - Err(diagnostic) => diagnostic.abort(), - } + let second_param = + extract_type_from_fn_arg(emitter, params_iter.next().unwrap().clone())?; + + let state_ty = emitter.handle(parse_state_param(&second_param))?; + + let first_param = maybe_first_param?; + + Some((Some(first_param.ty.deref().clone()), Some(state_ty.clone()))) + } + _ => { + emit!(emitter, params, "No more than 2 parameters are allowed"); + None } - _ => abort!(params, "No more than 2 parameters are allowed"), } } -fn parse_state_param(param: &syn::PatType) -> Result<&syn::Type, Diagnostic> { - let syn::Pat::Ident(pat_ident) = &*param.pat else { - return Err(diagnostic!( - param, - Level::Error, - "State parameter should be an ident" - )); +fn parse_state_param(param: &syn2::PatType) -> Result<&syn2::Type> { + let syn2::Pat::Ident(pat_ident) = &*param.pat else { + bail!(param, "State parameter should be an ident"); }; if !["state", "_state"].contains(&&*pat_ident.ident.to_string()) { - return Err(diagnostic!( - param, - Level::Error, - "State parameter should be named `state` or `_state`" - )); + bail!(param, "State parameter should be named `state` or `_state`"); } - let syn::Type::Reference(ty_ref) = &*param.ty else { - return Err(diagnostic!( + let syn2::Type::Reference(ty_ref) = &*param.ty else { + bail!( param.ty, - Level::Error, "State parameter should be either reference or mutable reference" - )); + ); }; Ok(&*ty_ref.elem) } fn classify_ok_type( - generics: &Punctuated, -) -> Option { - let ok_generic = generics - .first() - .expect_or_abort("First generic argument expected in `Result` return type"); - let syn::GenericArgument::Type(ok_type) = ok_generic else { - abort!( + generics: &Punctuated, +) -> Result> { + let Some(ok_generic) = generics.first() else { + bail!("First generic argument expected in `Result` return type"); + }; + let syn2::GenericArgument::Type(ok_type) = ok_generic else { + bail!( ok_generic, "First generic of `Result` return type expected to be a type" ); }; - if let syn::Type::Tuple(syn::TypeTuple { elems, .. }) = ok_type { - (!elems.is_empty()).then_some(ok_type.clone()) + if let syn2::Type::Tuple(syn2::TypeTuple { elems, .. }) = ok_type { + Ok((!elems.is_empty()).then_some(ok_type.clone())) } else { - Some(ok_type.clone()) + Ok(Some(ok_type.clone())) } } -fn extract_err_type(generics: &Punctuated) -> &syn::Type { - let err_generic = generics - .iter() - .nth(1) - .expect_or_abort("Second generic argument expected in `Result` return type"); - let syn::GenericArgument::Type(err_type) = err_generic else { - abort!( - err_generic, +fn extract_err_type<'arg>( + emitter: &mut Emitter, + generics: &'arg Punctuated, +) -> Option<&'arg syn2::Type> { + let Some(err_generic) = generics.iter().nth(1) else { + emit!( + emitter, "Second generic of `Result` return type expected to be a type" ); + return None; }; - err_type -} -fn unwrap_path(ty: &syn::Type) -> &syn::Path { - let syn::Type::Path(syn::TypePath { ref path, .. }) = *ty else { - abort!(ty, "Expected path"); - }; + if let syn2::GenericArgument::Type(err_type) = err_generic { + Some(err_type) + } else { + emit!( + emitter, + err_generic, + "Second generic of `Result` return type expected to be a type" + ); + None + } +} - path +fn unwrap_path<'ty>(emitter: &mut Emitter, ty: &'ty syn2::Type) -> Option<&'ty syn2::Path> { + if let syn2::Type::Path(syn2::TypePath { ref path, .. }) = ty { + Some(path) + } else { + emit!(emitter, ty, "Expected path"); + None + } } -fn last_segment(path: &syn::Path) -> &syn::PathSegment { - path.segments - .last() - .expect_or_abort("At least one path segment expected") +fn last_segment<'path>( + emitter: &mut Emitter, + path: &'path syn2::Path, +) -> Option<&'path syn2::PathSegment> { + path.segments.last().or_else(|| { + emit!(emitter, "At least one path segment expected"); + None + }) } fn retrieve_state_ty<'ty>( - state_ty_from_attr: Option<&'ty syn::Type>, - state_ty_from_fn_sig: Option<&'ty syn::Type>, -) -> &'ty syn::Type { - state_ty_from_attr - .or(state_ty_from_fn_sig) - .expect_or_abort("`state` attribute is required") + emitter: &mut Emitter, + state_ty_from_attr: Option<&'ty syn2::Type>, + state_ty_from_fn_sig: Option<&'ty syn2::Type>, +) -> Option<&'ty syn2::Type> { + state_ty_from_attr.or(state_ty_from_fn_sig).or_else(|| { + emit!(emitter, "`state` attribute is required"); + None + }) }