Skip to content

Commit

Permalink
Allow calling the delegated method through a trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Kobzol committed Jun 29, 2023
1 parent 8cf6434 commit e51dc42
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 17 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@ delegate! {
}
}
```
- Add new `#[unwrap]` expression modifier. Adding it on top of a delegated method will cause the generated
- Add new `#[unwrap]` method modifier. Adding it on top of a delegated method will cause the generated
code to `.unwrap()` the result.
```rust
#[unwrap]
fn foo(&self) -> u32; // foo().unwrap()
```
- Add new `#[through(<trait>)]` method modifier. Adding it on top of a delegated method will cause the generated
code to call the method through the provided trait using [UFCS](https://doc.rust-lang.org/reference/expressions/call-expr.html#disambiguating-function-calls).
```rust
#[through(MyTrait)]
delegate! {
to &self.inner {
#[through(MyTrait)]
fn foo(&self) -> u32; // MyTrait::foo(&self.inner)
}
}
```
- Removed `#[try_into(unwrap)]`. It can now be replaced with the combination of `#[try_into]` and `#[unwrap]`:
```rust
#[try_into]
Expand Down
67 changes: 60 additions & 7 deletions src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::VecDeque;

use syn::parse::ParseStream;
use syn::{Attribute, Error, Meta, TypePath};

Expand Down Expand Up @@ -56,6 +57,26 @@ impl syn::parse::Parse for IntoAttribute {
}
}

pub struct TraitTarget {
type_path: TypePath,
}

impl syn::parse::Parse for TraitTarget {
fn parse(input: ParseStream) -> Result<Self, Error> {
let content;
syn::parenthesized!(content in input);

let type_path: TypePath = content.parse().map_err(|error| {
Error::new(
input.span(),
format!("{error}\nExpected trait path, e.g. #[through(foo::MyTrait)]"),
)
})?;

Ok(TraitTarget { type_path })
}
}

#[derive(Clone)]
pub enum ReturnExpression {
Into(Option<TypePath>),
Expand All @@ -67,6 +88,7 @@ enum ParsedAttribute {
ReturnExpression(ReturnExpression),
Await(bool),
TargetMethod(syn::Ident),
ThroughTrait(TraitTarget),
}

fn parse_attributes(
Expand All @@ -86,12 +108,13 @@ fn parse_attributes(
.unwrap_or_default();
match name.as_str() {
"call" => {
let target =
syn::parse2::<CallMethodAttribute>(attribute.tokens.clone()).unwrap();
let target = syn::parse2::<CallMethodAttribute>(attribute.tokens.clone())
.expect("Cannot parse `call` attribute");
Some(ParsedAttribute::TargetMethod(target.name))
}
"into" => {
let into = syn::parse2::<IntoAttribute>(attribute.tokens.clone()).unwrap();
let into = syn::parse2::<IntoAttribute>(attribute.tokens.clone())
.expect("Cannot parse `into` attribute");
Some(ParsedAttribute::ReturnExpression(ReturnExpression::Into(
into.type_path,
)))
Expand All @@ -118,9 +141,13 @@ fn parse_attributes(
"await" => {
let generate =
syn::parse2::<GenerateAwaitAttribute>(attribute.tokens.clone())
.unwrap();
.expect("Cannot parse `await` attribute");
Some(ParsedAttribute::Await(generate.literal.value))
}
"through" => Some(ParsedAttribute::ThroughTrait(
syn::parse2::<TraitTarget>(attribute.tokens.clone())
.expect("Cannot parse `through` attribute"),
)),
_ => None,
}
} else {
Expand All @@ -141,6 +168,7 @@ pub struct MethodAttributes<'a> {
pub target_method: Option<syn::Ident>,
pub expressions: VecDeque<ReturnExpression>,
pub generate_await: Option<bool>,
pub target_trait: Option<TypePath>,
}

/// Iterates through the attributes of a method and filters special attributes.
Expand All @@ -149,13 +177,15 @@ pub struct MethodAttributes<'a> {
/// - try_into => generates a `try_into()` call after the delegated expression
/// - await => generates an `.await` expression after the delegated expression
/// - unwrap => generates a `unwrap()` call after the delegated expression
/// - throuhg => generates a UFCS call (`Target::method(&<expr>, ...)`) around the delegated expression
pub fn parse_method_attributes<'a>(
attrs: &'a [syn::Attribute],
attrs: &'a [Attribute],
method: &syn::TraitItemMethod,
) -> MethodAttributes<'a> {
let mut target_method: Option<syn::Ident> = None;
let mut expressions: Vec<ReturnExpression> = vec![];
let mut generate_await: Option<bool> = None;
let mut target_trait: Option<TraitTarget> = None;

let (parsed, other) = parse_attributes(attrs);
for attr in parsed {
Expand All @@ -179,6 +209,15 @@ pub fn parse_method_attributes<'a>(
}
target_method = Some(target);
}
ParsedAttribute::ThroughTrait(target) => {
if target_trait.is_some() {
panic!(
"Multiple through attributes specified for {}",
method.sig.ident
)
}
target_trait = Some(target);
}
}
}

Expand All @@ -187,17 +226,20 @@ pub fn parse_method_attributes<'a>(
target_method,
generate_await,
expressions: expressions.into(),
target_trait: target_trait.map(|t| t.type_path),
}
}

pub struct SegmentAttributes {
pub expressions: Vec<ReturnExpression>,
pub generate_await: Option<bool>,
pub target_trait: Option<TypePath>,
}

pub fn parse_segment_attributes(attrs: &[syn::Attribute]) -> SegmentAttributes {
pub fn parse_segment_attributes(attrs: &[Attribute]) -> SegmentAttributes {
let mut expressions: Vec<ReturnExpression> = vec![];
let mut generate_await: Option<bool> = None;
let mut target_trait: Option<TraitTarget> = None;

let (parsed, mut other) = parse_attributes(attrs);
if other.next().is_some() {
Expand All @@ -209,10 +251,16 @@ pub fn parse_segment_attributes(attrs: &[syn::Attribute]) -> SegmentAttributes {
ParsedAttribute::ReturnExpression(expr) => expressions.push(expr),
ParsedAttribute::Await(value) => {
if generate_await.is_some() {
panic!("Multiple `await` attributes specified for segment",)
panic!("Multiple `await` attributes specified for segment");
}
generate_await = Some(value);
}
ParsedAttribute::ThroughTrait(target) => {
if target_trait.is_some() {
panic!("Multiple `through` attributes specified for segment");
}
target_trait = Some(target);
}
ParsedAttribute::TargetMethod(_) => {
panic!("Call attribute cannot be specified on a `to <expr>` segment.");
}
Expand All @@ -221,6 +269,7 @@ pub fn parse_segment_attributes(attrs: &[syn::Attribute]) -> SegmentAttributes {
SegmentAttributes {
expressions,
generate_await,
target_trait: target_trait.map(|t| t.type_path),
}
}

Expand All @@ -233,6 +282,10 @@ pub fn combine_attributes<'a>(
method_attrs.generate_await = segment_attrs.generate_await;
}

if method_attrs.target_trait.is_none() {
method_attrs.target_trait = segment_attrs.target_trait.clone();
}

for expr in &segment_attrs.expressions {
match expr {
ReturnExpression::Into(path) => {
Expand Down
53 changes: 44 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,24 +272,53 @@
//! }
//! }
//! ```
mod attributes;
//! - Specify a trait through which will the delegated method be called
//! (using [UFCS](https://doc.rust-lang.org/reference/expressions/call-expr.html#disambiguating-function-calls).
//! ```rust
//! use delegate::delegate;
//!
//! struct InnerType {}
//! impl InnerType {
//!
//! }
//!
//! trait MyTrait {
//! fn foo(&self);
//! }
//! impl MyTrait for InnerType {
//! fn foo(&self) {}
//! }
//!
//! struct Wrapper(InnerType);
//! impl Wrapper {
//! delegate! {
//! to &self.0 {
//! // Calls `MyTrait::foo(&self.0)`
//! #[through(MyTrait)]
//! pub fn foo(&self);
//! }
//! }
//! }
//! ```
extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::Ident;

use crate::attributes::{
combine_attributes, parse_method_attributes, parse_segment_attributes, ReturnExpression,
SegmentAttributes,
};
use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::parse::ParseStream;
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::{parse_quote, Error, ExprMethodCall, Meta};

use crate::attributes::{
combine_attributes, parse_method_attributes, parse_segment_attributes, ReturnExpression,
SegmentAttributes,
};

mod attributes;

mod kw {
syn::custom_keyword!(to);
syn::custom_keyword!(target);
Expand Down Expand Up @@ -373,7 +402,7 @@ struct DelegatedMethod {
// argument used to call the delegate function: omit receiver, extract an
// identifier from a typed input parameter (and wrap it in an `Expr`).
fn parse_input_into_argument_expression(
function_name: &syn::Ident,
function_name: &Ident,
input: &syn::FnArg,
) -> Option<syn::Expr> {
match input {
Expand Down Expand Up @@ -614,7 +643,11 @@ impl syn::parse::Parse for DelegatedSegment {

let mut methods = vec![];
while !content.is_empty() {
methods.push(content.parse::<DelegatedMethod>().unwrap());
methods.push(
content
.parse::<DelegatedMethod>()
.expect("Cannot parse delegated method"),
);
}

Ok(DelegatedSegment {
Expand Down Expand Up @@ -709,6 +742,8 @@ pub fn delegate(tokens: TokenStream) -> TokenStream {
}
.visit_expr_match_mut(&mut expr_match);
expr_match.into_token_stream()
} else if let Some(target_trait) = attributes.target_trait {
quote::quote! { #target_trait::#name(#delegator_attribute, #(#args),*) }
} else {
quote::quote! { #delegator_attribute.#name(#(#args),*) }
};
Expand Down
42 changes: 42 additions & 0 deletions tests/segment_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,45 @@ fn test_segment_await() {
}
}
}

#[test]
fn test_segment_through_trait() {
trait A {
fn f(&self) -> u32;
}

trait B {
fn f(&self) -> u32;
}

struct Foo;

impl A for Foo {
fn f(&self) -> u32 {
0
}
}
impl B for Foo {
fn f(&self) -> u32 {
1
}
}

struct Bar(Foo);

impl Bar {
delegate! {
#[through(A)]
to &self.0 {
fn f(&self) -> u32;
#[call(f)]
#[through(B)]
fn f2(&self) -> u32;
}
}
}

let bar = Bar(Foo);
assert_eq!(bar.f(), 0);
assert_eq!(bar.f2(), 1);
}
43 changes: 43 additions & 0 deletions tests/through_trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use delegate::delegate;

#[test]
fn test_call_through_trait() {
trait A {
fn f(&self) -> u32;
}

trait B {
fn f(&self) -> u32;
}

struct Foo;

impl A for Foo {
fn f(&self) -> u32 {
0
}
}
impl B for Foo {
fn f(&self) -> u32 {
1
}
}

struct Bar(Foo);

impl Bar {
delegate! {
to &self.0 {
#[through(A)]
fn f(&self) -> u32;
#[call(f)]
#[through(B)]
fn f2(&self) -> u32;
}
}
}

let bar = Bar(Foo);
assert_eq!(bar.f(), 0);
assert_eq!(bar.f2(), 1);
}

0 comments on commit e51dc42

Please sign in to comment.