diff --git a/pin-project-internal/src/lib.rs b/pin-project-internal/src/lib.rs index 748758d7..4bdd0ed7 100644 --- a/pin-project-internal/src/lib.rs +++ b/pin-project-internal/src/lib.rs @@ -92,7 +92,7 @@ use syn::parse::Nothing; /// } /// /// impl Foo { -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// let this = self.project(); /// let _: Pin<&mut T> = this.future; // Pinned reference to the field /// let _: &mut U = this.field; // Normal reference to the field @@ -115,7 +115,7 @@ use syn::parse::Nothing; /// } /// /// impl Foo { -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// let this = self.project(); /// let _: Pin<&mut T> = this.future; // Pinned reference to the field /// let _: &mut U = this.field; // Normal reference to the field @@ -162,7 +162,7 @@ use syn::parse::Nothing; /// } /// /// #[pinned_drop] -/// fn my_drop_fn(foo: Pin<&mut Foo>) { +/// fn my_drop_fn(mut foo: Pin<&mut Foo>) { /// let foo = foo.project(); /// println!("Dropping pinned field: {:?}", foo.pinned_field); /// println!("Dropping unpin field: {:?}", foo.unpin_field); @@ -193,7 +193,7 @@ use syn::parse::Nothing; /// } /// /// impl Foo { -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// let this = self.project(); /// let _: Pin<&mut T> = this.future; /// let _: &mut U = this.field; @@ -211,7 +211,7 @@ use syn::parse::Nothing; /// struct Foo(#[pin] T, U); /// /// impl Foo { -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// let this = self.project(); /// let _: Pin<&mut T> = this.0; /// let _: &mut U = this.1; @@ -250,7 +250,7 @@ use syn::parse::Nothing; /// # #[cfg(feature = "project_attr")] /// impl Foo { /// #[project] // Nightly does not need a dummy attribute to the function. -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// #[project] /// match self.project() { /// Foo::Tuple(x, y) => { @@ -347,7 +347,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// /// impl Foo { /// #[project] // Nightly does not need a dummy attribute to the function. -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// #[project] /// let Foo { future, field } = self.project(); /// @@ -372,7 +372,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// /// impl Foo { /// #[project] // Nightly does not need a dummy attribute to the function. -/// fn baz(self: Pin<&mut Self>) { +/// fn baz(mut self: Pin<&mut Self>) { /// #[project] /// match self.project() { /// Foo::Tuple(x, y) => { diff --git a/pin-project-internal/src/pin_project/enums.rs b/pin-project-internal/src/pin_project/enums.rs index 3262c01e..b0a47a27 100644 --- a/pin-project-internal/src/pin_project/enums.rs +++ b/pin-project-internal/src/pin_project/enums.rs @@ -6,7 +6,7 @@ use crate::utils::VecExt; use super::{proj_generics, Context, PIN}; -pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result { +pub(super) fn parse(cx: &mut Context, mut item: ItemEnum) -> Result { if item.variants.is_empty() { return Err(error!(item, "cannot be implemented for enums without variants")); } @@ -23,22 +23,23 @@ pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result return Err(error!(item.variants, "cannot be implemented for enums that have no field")); } - let (proj_variants, proj_arms) = variants(&mut cx, &mut item)?; + let (proj_variants, proj_arms) = variants(cx, &mut item)?; - let impl_drop = cx.impl_drop(&item.generics); + let mut impl_drop = cx.impl_drop(&item.generics); let Context { original, projected, lifetime, impl_unpin, .. } = cx; let proj_generics = proj_generics(&item.generics, &lifetime); let proj_ty_generics = proj_generics.split_for_impl().1; + let proj_trait = &cx.projected_trait; let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); let mut proj_items = quote! { enum #projected #proj_generics #where_clause { #(#proj_variants,)* } }; let proj_method = quote! { - impl #impl_generics #original #ty_generics #where_clause { - fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #projected #proj_ty_generics { + impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #original #ty_generics> #where_clause { + fn project<#lifetime>(&#lifetime mut self) -> #projected #proj_ty_generics #where_clause { unsafe { - match ::core::pin::Pin::get_unchecked_mut(self) { + match self.as_mut().get_unchecked_mut() { #(#proj_arms,)* } } diff --git a/pin-project-internal/src/pin_project/mod.rs b/pin-project-internal/src/pin_project/mod.rs index 64f75944..80fdb902 100644 --- a/pin-project-internal/src/pin_project/mod.rs +++ b/pin-project-internal/src/pin_project/mod.rs @@ -4,11 +4,10 @@ use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, token::Comma, - Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Index, Item, ItemStruct, Lifetime, - LifetimeDef, Meta, NestedMeta, Result, Type, + *, }; -use crate::utils::{crate_path, proj_ident}; +use crate::utils::{crate_path, proj_ident, proj_trait_ident}; mod enums; mod structs; @@ -51,6 +50,10 @@ struct Context { original: Ident, /// Name of the projected type. projected: Ident, + /// Name of the trait generated + /// to provide a 'project' method + projected_trait: Ident, + generics: Generics, lifetime: Lifetime, impl_unpin: ImplUnpin, @@ -63,7 +66,16 @@ impl Context { let projected = proj_ident(&original); let lifetime = proj_lifetime(&generics.params); let impl_unpin = ImplUnpin::new(generics, unsafe_unpin); - Ok(Self { original, projected, lifetime, impl_unpin, pinned_drop }) + let projected_trait = proj_trait_ident(&original); + Ok(Self { + original, + projected, + projected_trait, + lifetime, + impl_unpin, + pinned_drop, + generics: generics.clone(), + }) } fn impl_drop<'a>(&self, generics: &'a Generics) -> ImplDrop<'a> { @@ -74,22 +86,42 @@ impl Context { fn parse(args: TokenStream, input: TokenStream) -> Result { match syn::parse2(input)? { Item::Struct(item) => { - let cx = Context::new(args, item.ident.clone(), &item.generics)?; - let packed_check = ensure_not_packed(&item)?; - let mut res = structs::parse(cx, item)?; - res.extend(packed_check); + let mut cx = Context::new(args, item.ident.clone(), &item.generics)?; + + let mut res = structs::parse(&mut cx, item.clone())?; + res.extend(ensure_not_packed(&item)?); + res.extend(make_proj_trait(&mut cx)?); Ok(res) } Item::Enum(item) => { - let cx = Context::new(args, item.ident.clone(), &item.generics)?; + let mut cx = Context::new(args, item.ident.clone(), &item.generics)?; + // We don't need to check for '#[repr(packed)]', // since it does not apply to enums - enums::parse(cx, item) + let mut res = enums::parse(&mut cx, item.clone())?; + res.extend(make_proj_trait(&mut cx)?); + Ok(res) } item => Err(error!(item, "may only be used on structs or enums")), } } +fn make_proj_trait(cx: &mut Context) -> Result { + let proj_trait = &cx.projected_trait; + let lifetime = &cx.lifetime; + let proj_ident = &cx.projected; + let proj_generics = proj_generics(&cx.generics, &cx.lifetime); + let proj_ty_generics = proj_generics.split_for_impl().1; + + let (orig_generics, _orig_ty_generics, orig_where_clause) = cx.generics.split_for_impl(); + + Ok(quote! { + trait #proj_trait #orig_generics { + fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #orig_where_clause; + } + }) +} + fn ensure_not_packed(item: &ItemStruct) -> Result { for meta in item.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) { if let Meta::List(l) = meta { @@ -220,7 +252,7 @@ impl<'a> ImplDrop<'a> { Self { generics, pinned_drop } } - fn build(self, ident: &Ident) -> TokenStream { + fn build(&mut self, ident: &Ident) -> TokenStream { let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); if let Some(pinned_drop) = self.pinned_drop { @@ -292,7 +324,7 @@ impl ImplUnpin { } /// Creates `Unpin` implementation. - fn build(self, ident: &Ident) -> TokenStream { + fn build(&mut self, ident: &Ident) -> TokenStream { let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); quote! { impl #impl_generics ::core::marker::Unpin for #ident #ty_generics #where_clause {} diff --git a/pin-project-internal/src/pin_project/structs.rs b/pin-project-internal/src/pin_project/structs.rs index 4ed64bc1..4efa6155 100644 --- a/pin-project-internal/src/pin_project/structs.rs +++ b/pin-project-internal/src/pin_project/structs.rs @@ -6,7 +6,7 @@ use crate::utils::VecExt; use super::{proj_generics, Context, PIN}; -pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result { +pub(super) fn parse(cx: &mut Context, mut item: ItemStruct) -> Result { let (proj_fields, proj_init) = match &mut item.fields { Fields::Named(FieldsNamed { named: fields, .. }) | Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) @@ -16,26 +16,27 @@ pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result return Err(error!(item, "cannot be implemented for structs with units")), - Fields::Named(fields) => named(&mut cx, fields)?, - Fields::Unnamed(fields) => unnamed(&mut cx, fields)?, + Fields::Named(fields) => named(cx, fields)?, + Fields::Unnamed(fields) => unnamed(cx, fields)?, }; let orig_ident = &cx.original; let proj_ident = &cx.projected; let lifetime = &cx.lifetime; - let impl_drop = cx.impl_drop(&item.generics); + let mut impl_drop = cx.impl_drop(&item.generics); let proj_generics = proj_generics(&item.generics, &cx.lifetime); let proj_ty_generics = proj_generics.split_for_impl().1; + let proj_trait = &cx.projected_trait; let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); let mut proj_items = quote! { struct #proj_ident #proj_generics #where_clause #proj_fields }; let proj_method = quote! { - impl #impl_generics #orig_ident #ty_generics #where_clause { - fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #proj_ident #proj_ty_generics { + impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause { + fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause { unsafe { - let this = ::core::pin::Pin::get_unchecked_mut(self); + let this = self.as_mut().get_unchecked_mut(); #proj_ident #proj_init } } diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index c411177a..8c482415 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -7,6 +7,10 @@ pub(crate) fn proj_ident(ident: &Ident) -> Ident { format_ident!("__{}Projection", ident) } +pub(crate) fn proj_trait_ident(ident: &Ident) -> Ident { + format_ident!("__{}ProjectionTrait", ident) +} + pub(crate) trait VecExt { fn find_remove(&mut self, ident: &str) -> Option; } diff --git a/src/lib.rs b/src/lib.rs index 2aebc6f9..60357506 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ //! } //! //! impl Foo { -//! fn baz(self: Pin<&mut Self>) { +//! fn baz(mut self: Pin<&mut Self>) { //! let this = self.project(); //! let _: Pin<&mut T> = this.future; // Pinned reference to the field //! let _: &mut U = this.field; // Normal reference to the field diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 2d5d7243..6086c576 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -19,7 +19,8 @@ fn test_pin_project() { let mut foo = Foo { field1: 1, field2: 2 }; - let foo = Pin::new(&mut foo).project(); + let mut foo_orig = Pin::new(&mut foo); + let foo = foo_orig.project(); let x: Pin<&mut i32> = foo.field1; assert_eq!(*x, 1); @@ -27,9 +28,13 @@ fn test_pin_project() { let y: &mut i32 = foo.field2; assert_eq!(*y, 2); + assert_eq!(foo_orig.as_ref().field1, 1); + assert_eq!(foo_orig.as_ref().field2, 2); + let mut foo = Foo { field1: 1, field2: 2 }; - let foo = Pin::new(&mut foo).project(); + let mut foo = Pin::new(&mut foo); + let foo = foo.project(); let __FooProjection { field1, field2 } = foo; let _: Pin<&mut i32> = field1; @@ -42,7 +47,8 @@ fn test_pin_project() { let mut bar = Bar(1, 2); - let bar = Pin::new(&mut bar).project(); + let mut bar = Pin::new(&mut bar); + let bar = bar.project(); let x: Pin<&mut i32> = bar.0; assert_eq!(*x, 1); @@ -53,6 +59,7 @@ fn test_pin_project() { // enum #[pin_project] + #[derive(Eq, PartialEq, Debug)] enum Baz { Variant1(#[pin] A, B), Variant2 { @@ -65,7 +72,8 @@ fn test_pin_project() { let mut baz = Baz::Variant1(1, 2); - let baz = Pin::new(&mut baz).project(); + let mut baz_orig = Pin::new(&mut baz); + let baz = baz_orig.project(); match baz { __BazProjection::Variant1(x, y) => { @@ -82,9 +90,12 @@ fn test_pin_project() { __BazProjection::None => {} } + assert_eq!(Pin::into_ref(baz_orig).get_ref(), &Baz::Variant1(1, 2)); + let mut baz = Baz::Variant2 { field1: 3, field2: 4 }; - let mut baz = Pin::new(&mut baz).project(); + let mut baz = Pin::new(&mut baz); + let mut baz = baz.project(); match &mut baz { __BazProjection::Variant1(x, y) => { @@ -110,6 +121,30 @@ fn test_pin_project() { } } +#[test] +fn enum_project_set() { + #[pin_project] + #[derive(Eq, PartialEq, Debug)] + enum Bar { + Variant1(#[pin] u8), + Variant2(bool), + } + + let mut bar = Bar::Variant1(25); + let mut bar_orig = Pin::new(&mut bar); + let bar_proj = bar_orig.project(); + + match bar_proj { + __BarProjection::Variant1(val) => { + let new_bar = Bar::Variant2(val.as_ref().get_ref() == &25); + bar_orig.set(new_bar); + } + _ => unreachable!(), + } + + assert_eq!(bar, Bar::Variant2(true)); +} + #[test] fn where_clause_and_associated_type_fields() { // struct diff --git a/tests/pinned_drop.rs b/tests/pinned_drop.rs index 6f331da0..f58a2cd4 100644 --- a/tests/pinned_drop.rs +++ b/tests/pinned_drop.rs @@ -14,7 +14,7 @@ pub struct Foo<'a> { } #[pinned_drop] -fn do_drop(foo: Pin<&mut Foo<'_>>) { +fn do_drop(mut foo: Pin<&mut Foo<'_>>) { **foo.project().was_dropped = true; } diff --git a/tests/project.rs b/tests/project.rs index 2adc961a..d61d5626 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -21,9 +21,10 @@ fn test_project_attr() { } let mut foo = Foo { field1: 1, field2: 2 }; + let mut foo = Pin::new(&mut foo); #[project] - let Foo { field1, field2 } = Pin::new(&mut foo).project(); + let Foo { field1, field2 } = foo.project(); let x: Pin<&mut i32> = field1; assert_eq!(*x, 1); @@ -37,9 +38,10 @@ fn test_project_attr() { struct Bar(#[pin] T, U); let mut bar = Bar(1, 2); + let mut bar = Pin::new(&mut bar); #[project] - let Bar(x, y) = Pin::new(&mut bar).project(); + let Bar(x, y) = bar.project(); let x: Pin<&mut i32> = x; assert_eq!(*x, 1); @@ -62,7 +64,8 @@ fn test_project_attr() { let mut baz = Baz::Variant1(1, 2); - let mut baz = Pin::new(&mut baz).project(); + let mut baz = Pin::new(&mut baz); + let mut baz = baz.project(); #[project] match &mut baz { @@ -98,7 +101,8 @@ fn test_project_attr_nightly() { let mut baz = Baz::Variant1(1, 2); - let mut baz = Pin::new(&mut baz).project(); + let mut baz = Pin::new(&mut baz); + let mut baz = baz.project(); #[project] match &mut baz {