From 6a14d66eb953cf16c6bf69df27d40379afa9d941 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Fri, 20 Sep 2019 20:28:19 +0900 Subject: [PATCH] Add project_ref method and #[project_ref] attribute --- ci/azure-test.yml | 8 +- examples/enum-default-expanded.rs | 16 ++ examples/pinned_drop-expanded.rs | 17 ++ examples/struct-default-expanded.rs | 16 ++ examples/unsafe_unpin-expanded.rs | 16 ++ pin-project-internal/src/lib.rs | 36 ++++- .../src/pin_project/attribute.rs | 117 +++++++++++--- pin-project-internal/src/project.rs | 66 +++++--- pin-project-internal/src/utils.rs | 16 +- src/lib.rs | 3 + tests/pin_project.rs | 29 +++- tests/project_ref.rs | 151 ++++++++++++++++++ 12 files changed, 428 insertions(+), 63 deletions(-) create mode 100644 tests/project_ref.rs diff --git a/ci/azure-test.yml b/ci/azure-test.yml index 82ebe379..eb8f1dd6 100644 --- a/ci/azure-test.yml +++ b/ci/azure-test.yml @@ -27,8 +27,12 @@ jobs: - ${{ if eq(parameters.toolchain, 'nightly') }}: - script: | - RUSTFLAGS='-Dwarnings -Zallow-features=proc_macro_hygiene,stmt_expr_attributes' cargo ${{ parameters.cmd }} --all --all-features - displayName: cargo ${{ parameters.cmd }} -Zallow-features + RUSTFLAGS='-Dwarnings --cfg pin_project_show_unpin_struct' cargo ${{ parameters.cmd }} --all --all-features + displayName: cargo ${{ parameters.cmd }} --cfg pin_project_show_unpin_struct + + - script: | + RUSTFLAGS='-Dwarnings -Zallow-features=proc_macro_hygiene,stmt_expr_attributes' cargo check --all --all-features + displayName: cargo check -Zallow-features # Refs: https://github.com/rust-lang/cargo/issues/5657 - script: | diff --git a/examples/enum-default-expanded.rs b/examples/enum-default-expanded.rs index 56bf694b..0033c718 100644 --- a/examples/enum-default-expanded.rs +++ b/examples/enum-default-expanded.rs @@ -30,6 +30,12 @@ enum __EnumProjection<'_pin, T, U> { Unpinned(&'_pin mut U), } +#[allow(dead_code)] // This lint warns unused fields/variants. +enum __EnumProjectionRef<'_pin, T, U> { + Pinned(::core::pin::Pin<&'_pin T>), + Unpinned(&'_pin U), +} + impl Enum { fn project<'_pin>(self: ::core::pin::Pin<&'_pin mut Self>) -> __EnumProjection<'_pin, T, U> { unsafe { @@ -39,6 +45,16 @@ impl Enum { } } } + fn project_ref<'_pin>(self: ::core::pin::Pin<&'_pin Self>) -> __EnumProjectionRef<'_pin, T, U> { + unsafe { + match self.get_ref() { + Enum::Pinned(_x0) => { + __EnumProjectionRef::Pinned(::core::pin::Pin::new_unchecked(_x0)) + } + Enum::Unpinned(_x0) => __EnumProjectionRef::Unpinned(_x0), + } + } + } } // Automatically create the appropriate conditional `Unpin` implementation. diff --git a/examples/pinned_drop-expanded.rs b/examples/pinned_drop-expanded.rs index 451af1de..593069c5 100644 --- a/examples/pinned_drop-expanded.rs +++ b/examples/pinned_drop-expanded.rs @@ -38,6 +38,12 @@ pub(crate) struct __FooProjection<'_pin, 'a, T> { field: ::core::pin::Pin<&'_pin mut T>, } +#[allow(dead_code)] +pub(crate) struct __FooProjectionRef<'_pin, 'a, T> { + was_dropped: &'_pin &'a mut bool, + field: ::core::pin::Pin<&'_pin T>, +} + impl<'a, T> Foo<'a, T> { pub(crate) fn project<'_pin>( self: ::core::pin::Pin<&'_pin mut Self>, @@ -50,6 +56,17 @@ impl<'a, T> Foo<'a, T> { } } } + pub(crate) fn project_ref<'_pin>( + self: ::core::pin::Pin<&'_pin Self>, + ) -> __FooProjectionRef<'_pin, 'a, T> { + unsafe { + let Foo { was_dropped, field } = self.get_ref(); + __FooProjectionRef { + was_dropped: was_dropped, + field: ::core::pin::Pin::new_unchecked(field), + } + } + } } #[allow(single_use_lifetimes)] diff --git a/examples/struct-default-expanded.rs b/examples/struct-default-expanded.rs index dc6b9d1d..748b1420 100644 --- a/examples/struct-default-expanded.rs +++ b/examples/struct-default-expanded.rs @@ -30,6 +30,11 @@ struct __StructProjection<'_pin, T, U> { pinned: ::core::pin::Pin<&'_pin mut T>, unpinned: &'_pin mut U, } +#[allow(dead_code)] // This lint warns unused fields/variants. +struct __StructProjectionRef<'_pin, T, U> { + pinned: ::core::pin::Pin<&'_pin T>, + unpinned: &'_pin U, +} impl Struct { fn project<'_pin>(self: ::core::pin::Pin<&'_pin mut Self>) -> __StructProjection<'_pin, T, U> { @@ -41,6 +46,17 @@ impl Struct { } } } + fn project_ref<'_pin>( + self: ::core::pin::Pin<&'_pin Self>, + ) -> __StructProjectionRef<'_pin, T, U> { + unsafe { + let Struct { pinned, unpinned } = self.get_ref(); + __StructProjectionRef { + pinned: ::core::pin::Pin::new_unchecked(pinned), + unpinned: unpinned, + } + } + } } // Automatically create the appropriate conditional `Unpin` implementation. diff --git a/examples/unsafe_unpin-expanded.rs b/examples/unsafe_unpin-expanded.rs index 145056c6..ac09dad3 100644 --- a/examples/unsafe_unpin-expanded.rs +++ b/examples/unsafe_unpin-expanded.rs @@ -32,6 +32,11 @@ pub(crate) struct __FooProjection<'_pin, T, U> { pinned: ::core::pin::Pin<&'_pin mut T>, unpinned: &'_pin mut U, } +#[allow(dead_code)] +pub(crate) struct __FooProjectionRef<'_pin, T, U> { + pinned: ::core::pin::Pin<&'_pin T>, + unpinned: &'_pin U, +} impl Foo { pub(crate) fn project<'_pin>( @@ -42,6 +47,17 @@ impl Foo { __FooProjection { pinned: ::core::pin::Pin::new_unchecked(pinned), unpinned: unpinned } } } + pub(crate) fn project_ref<'_pin>( + self: ::core::pin::Pin<&'_pin Self>, + ) -> __FooProjectionRef<'_pin, T, U> { + unsafe { + let Foo { pinned, unpinned } = self.get_ref(); + __FooProjectionRef { + pinned: ::core::pin::Pin::new_unchecked(pinned), + unpinned: unpinned, + } + } + } } unsafe impl UnsafeUnpin for Foo {} diff --git a/pin-project-internal/src/lib.rs b/pin-project-internal/src/lib.rs index f402f455..2e2e893c 100644 --- a/pin-project-internal/src/lib.rs +++ b/pin-project-internal/src/lib.rs @@ -24,6 +24,8 @@ mod project; use proc_macro::TokenStream; use syn::parse::Nothing; +use utils::{Immutable, Mutable}; + // TODO: Move this doc into pin-project crate when https://github.com/rust-lang/rust/pull/62855 merged. /// An attribute that creates a projection struct covering all the fields. /// @@ -33,13 +35,15 @@ use syn::parse::Nothing; /// the field. /// - For the other fields, makes the unpinned reference to the field. /// -/// The following method is implemented on the original `#[pin_project]` type: +/// The following methods are implemented on the original `#[pin_project]` type: /// /// ``` /// # use std::pin::Pin; -/// # type ProjectedType = (); -/// # trait Projection { -/// fn project(self: Pin<&mut Self>) -> ProjectedType; +/// # type Projection = (); +/// # type ProjectionRef = (); +/// # trait Dox { +/// fn project(self: Pin<&mut Self>) -> Projection; +/// fn project_ref(self: Pin<&Self>) -> ProjectionRef; /// # } /// ``` /// @@ -304,13 +308,14 @@ use syn::parse::Nothing; /// /// Enums without variants (zero-variant enums) are not supported. /// -/// See also [`project`] attribute. +/// See also [`project`] and [`project_ref`] attributes. /// /// [`Pin::as_mut`]: core::pin::Pin::as_mut /// [`Pin::set`]: core::pin::Pin::set /// [`drop`]: Drop::drop /// [`UnsafeUnpin`]: https://docs.rs/pin-project/0.4.0-alpha.11/pin_project/trait.UnsafeUnpin.html /// [`project`]: ./attr.project.html +/// [`project_ref`]: ./attr.project_ref.html /// [`pinned_drop`]: ./attr.pinned_drop.html #[proc_macro_attribute] pub fn pin_project(args: TokenStream, input: TokenStream) -> TokenStream { @@ -361,7 +366,8 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { } // TODO: Move this doc into pin-project crate when https://github.com/rust-lang/rust/pull/62855 merged. -/// An attribute to provide way to refer to the projected type. +/// An attribute to provide way to refer to the projected type returned by +/// `project` method. /// /// The following syntaxes are supported. /// @@ -507,7 +513,23 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { pub fn project(args: TokenStream, input: TokenStream) -> TokenStream { let _: Nothing = syn::parse_macro_input!(args); let input = syn::parse_macro_input!(input); - project::attribute(input).into() + project::attribute(input, Mutable).into() +} + +/// An attribute to provide way to refer to the projected type returned by +/// `project_ref` method. +/// +/// This is the same as [`project`] attribute except it refers to the projected +/// type returned by `project_ref` method. +/// +/// See [`project`] attribute for more details. +/// +/// [`project`]: ./attr.project.html +#[proc_macro_attribute] +pub fn project_ref(args: TokenStream, input: TokenStream) -> TokenStream { + let _: Nothing = syn::parse_macro_input!(args); + let input = syn::parse_macro_input!(input); + project::attribute(input, Immutable).into() } #[doc(hidden)] diff --git a/pin-project-internal/src/pin_project/attribute.rs b/pin-project-internal/src/pin_project/attribute.rs index 30a369e6..41271771 100644 --- a/pin-project-internal/src/pin_project/attribute.rs +++ b/pin-project-internal/src/pin_project/attribute.rs @@ -7,8 +7,8 @@ use syn::{ }; use crate::utils::{ - self, collect_cfg, crate_path, determine_visibility, proj_ident, proj_lifetime_name, VecExt, - DEFAULT_LIFETIME_NAME, + self, collect_cfg, crate_path, determine_visibility, proj_ident, proj_lifetime_name, Immutable, + Mutable, VecExt, DEFAULT_LIFETIME_NAME, }; use super::PIN; @@ -94,9 +94,12 @@ struct Context { /// Name of the original type. orig_ident: Ident, - /// Name of the projected type. + /// Name of the projected type returned by `project` method. proj_ident: Ident, + /// Name of the projected type returned by `project_ref` method. + proj_ref_ident: Ident, + /// Visibility of the original type. vis: Visibility, @@ -137,7 +140,8 @@ impl Context { Ok(Self { crate_path, orig_ident: orig_ident.clone(), - proj_ident: proj_ident(orig_ident), + proj_ident: proj_ident(orig_ident, Mutable), + proj_ref_ident: proj_ident(orig_ident, Immutable), vis: determine_visibility(vis), generics: generics.clone(), lifetime, @@ -263,8 +267,8 @@ impl Context { } /// Creates an implementation of the projection method. - fn make_proj_impl(&self, proj_body: &TokenStream) -> TokenStream { - let Context { proj_ident, orig_ident, vis, lifetime, .. } = self; + fn make_proj_impl(&self, proj_body: &TokenStream, proj_ref_body: &TokenStream) -> TokenStream { + let Context { orig_ident, proj_ident, proj_ref_ident, vis, lifetime, .. } = self; let proj_generics = self.proj_generics(); let proj_ty_generics = proj_generics.split_for_impl().1; @@ -280,6 +284,13 @@ impl Context { #proj_body } } + #vis fn project_ref<#lifetime>( + self: ::core::pin::Pin<&#lifetime Self>, + ) -> #proj_ref_ident #proj_ty_generics { + unsafe { + #proj_ref_body + } + } } } } @@ -385,13 +396,13 @@ impl Context { fn parse_struct(&self, item: &mut ItemStruct) -> Result { super::validate_struct(&item.ident, &item.fields)?; - let (proj_pat, proj_body, proj_fields) = match &mut item.fields { + let (proj_pat, proj_init, proj_fields, proj_ref_fields) = match &mut item.fields { Fields::Named(fields) => self.visit_named(fields)?, Fields::Unnamed(fields) => self.visit_unnamed(fields, true)?, Fields::Unit => unreachable!(), }; - let Context { orig_ident, proj_ident, vis, .. } = self; + let Context { orig_ident, proj_ident, proj_ref_ident, vis, .. } = self; let proj_generics = self.proj_generics(); let where_clause = item.generics.split_for_impl().2; @@ -399,14 +410,20 @@ impl Context { #[allow(clippy::mut_mut)] // This lint warns `&mut &mut `. #[allow(dead_code)] // This lint warns unused fields/variants. #vis struct #proj_ident #proj_generics #where_clause #proj_fields + #[allow(dead_code)] // This lint warns unused fields/variants. + #vis struct #proj_ref_ident #proj_generics #where_clause #proj_ref_fields }; let proj_body = quote! { let #orig_ident #proj_pat = self.get_unchecked_mut(); - #proj_ident #proj_body + #proj_ident #proj_init + }; + let proj_ref_body = quote! { + let #orig_ident #proj_pat = self.get_ref(); + #proj_ref_ident #proj_init }; - proj_items.extend(self.make_proj_impl(&proj_body)); + proj_items.extend(self.make_proj_impl(&proj_body, &proj_ref_body)); Ok(proj_items) } @@ -414,9 +431,10 @@ impl Context { fn parse_enum(&self, item: &mut ItemEnum) -> Result { super::validate_enum(item.brace_token, &item.variants)?; - let (proj_variants, proj_arms) = self.visit_variants(item)?; + let (proj_variants, proj_ref_variants, proj_arms, proj_ref_arms) = + self.visit_variants(item)?; - let Context { proj_ident, vis, .. } = &self; + let Context { proj_ident, proj_ref_ident, vis, .. } = &self; let proj_generics = self.proj_generics(); let where_clause = item.generics.split_for_impl().2; @@ -426,6 +444,10 @@ impl Context { #vis enum #proj_ident #proj_generics #where_clause { #(#proj_variants,)* } + #[allow(dead_code)] // This lint warns unused fields/variants. + #vis enum #proj_ref_ident #proj_generics #where_clause { + #(#proj_ref_variants,)* + } }; let proj_body = quote! { @@ -433,45 +455,71 @@ impl Context { #(#proj_arms)* } }; + let proj_ref_body = quote! { + match self.get_ref() { + #(#proj_ref_arms)* + } + }; - proj_items.extend(self.make_proj_impl(&proj_body)); + proj_items.extend(self.make_proj_impl(&proj_body, &proj_ref_body)); Ok(proj_items) } - fn visit_variants(&self, item: &mut ItemEnum) -> Result<(Vec, Vec)> { + #[allow(clippy::type_complexity)] + fn visit_variants( + &self, + item: &mut ItemEnum, + ) -> Result<(Vec, Vec, Vec, Vec)> { let mut proj_variants = Vec::with_capacity(item.variants.len()); + let mut proj_ref_variants = Vec::with_capacity(item.variants.len()); let mut proj_arms = Vec::with_capacity(item.variants.len()); + let mut proj_ref_arms = Vec::with_capacity(item.variants.len()); for Variant { attrs, fields, ident, .. } in &mut item.variants { - let (proj_pat, proj_body, proj_fields) = match fields { + let (proj_pat, proj_body, proj_fields, proj_ref_fields) = match fields { Fields::Named(fields) => self.visit_named(fields)?, Fields::Unnamed(fields) => self.visit_unnamed(fields, false)?, - Fields::Unit => (TokenStream::new(), TokenStream::new(), TokenStream::new()), + Fields::Unit => { + (TokenStream::new(), TokenStream::new(), TokenStream::new(), TokenStream::new()) + } }; + let cfg = collect_cfg(attrs); - let Self { orig_ident, proj_ident, .. } = self; + let Self { orig_ident, proj_ident, proj_ref_ident, .. } = self; proj_variants.push(quote! { #(#cfg)* #ident #proj_fields }); + proj_ref_variants.push(quote! { + #(#cfg)* + #ident #proj_ref_fields + }); proj_arms.push(quote! { #(#cfg)* #orig_ident::#ident #proj_pat => { #proj_ident::#ident #proj_body } }); + proj_ref_arms.push(quote! { + #(#cfg)* + #orig_ident::#ident #proj_pat => { + #proj_ref_ident::#ident #proj_body + } + }); } - Ok((proj_variants, proj_arms)) + Ok((proj_variants, proj_ref_variants, proj_arms, proj_ref_arms)) } + #[allow(clippy::cognitive_complexity)] fn visit_named( &self, FieldsNamed { named: fields, .. }: &mut FieldsNamed, - ) -> Result<(TokenStream, TokenStream, TokenStream)> { + ) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> { let mut proj_pat = Vec::with_capacity(fields.len()); let mut proj_body = Vec::with_capacity(fields.len()); let mut proj_fields = Vec::with_capacity(fields.len()); + let mut proj_ref_fields = Vec::with_capacity(fields.len()); for Field { attrs, vis, ident, ty, .. } in fields { let cfg = collect_cfg(attrs); if self.find_pin_attr(attrs)? { @@ -480,6 +528,10 @@ impl Context { #(#cfg)* #vis #ident: ::core::pin::Pin<&#lifetime mut #ty> }); + proj_ref_fields.push(quote! { + #(#cfg)* + #vis #ident: ::core::pin::Pin<&#lifetime #ty> + }); proj_body.push(quote! { #(#cfg)* #ident: ::core::pin::Pin::new_unchecked(#ident) @@ -490,6 +542,10 @@ impl Context { #(#cfg)* #vis #ident: &#lifetime mut #ty }); + proj_ref_fields.push(quote! { + #(#cfg)* + #vis #ident: &#lifetime #ty + }); proj_body.push(quote! { #(#cfg)* #ident: #ident @@ -503,17 +559,20 @@ impl Context { let proj_pat = quote!({ #(#proj_pat),* }); let proj_body = quote!({ #(#proj_body),* }); let proj_fields = quote!({ #(#proj_fields),* }); - Ok((proj_pat, proj_body, proj_fields)) + let proj_ref_fields = quote!({ #(#proj_ref_fields),* }); + + Ok((proj_pat, proj_body, proj_fields, proj_ref_fields)) } fn visit_unnamed( &self, FieldsUnnamed { unnamed: fields, .. }: &mut FieldsUnnamed, is_struct: bool, - ) -> Result<(TokenStream, TokenStream, TokenStream)> { + ) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> { let mut proj_pat = Vec::with_capacity(fields.len()); let mut proj_body = Vec::with_capacity(fields.len()); let mut proj_fields = Vec::with_capacity(fields.len()); + let mut proj_ref_fields = Vec::with_capacity(fields.len()); for (i, Field { attrs, vis, ty, .. }) in fields.iter_mut().enumerate() { let id = format_ident!("_x{}", i); let cfg = collect_cfg(attrs); @@ -529,6 +588,9 @@ impl Context { proj_fields.push(quote! { #vis ::core::pin::Pin<&#lifetime mut #ty> }); + proj_ref_fields.push(quote! { + #vis ::core::pin::Pin<&#lifetime #ty> + }); proj_body.push(quote! { ::core::pin::Pin::new_unchecked(#id) }); @@ -537,6 +599,9 @@ impl Context { proj_fields.push(quote! { #vis &#lifetime mut #ty }); + proj_ref_fields.push(quote! { + #vis &#lifetime #ty + }); proj_body.push(quote! { #id }); @@ -548,8 +613,12 @@ impl Context { let proj_pat = quote!((#(#proj_pat),*)); let proj_body = quote!((#(#proj_body),*)); - let proj_fields = - if is_struct { quote!((#(#proj_fields),*);) } else { quote!((#(#proj_fields),*)) }; - Ok((proj_pat, proj_body, proj_fields)) + let (proj_fields, proj_ref_fields) = if is_struct { + (quote!((#(#proj_fields),*);), quote!((#(#proj_ref_fields),*);)) + } else { + (quote!((#(#proj_fields),*)), quote!((#(#proj_ref_fields),*))) + }; + + Ok((proj_pat, proj_body, proj_fields, proj_ref_fields)) } } diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 000ed863..5193e550 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -6,37 +6,41 @@ use syn::{ *, }; -use crate::utils::{proj_generics, proj_ident, proj_lifetime_name, VecExt, DEFAULT_LIFETIME_NAME}; - -/// The attribute name. -const NAME: &str = "project"; +use crate::utils::{ + proj_generics, proj_ident, proj_lifetime_name, Mutability, Mutable, VecExt, + DEFAULT_LIFETIME_NAME, +}; -pub(crate) fn attribute(input: Stmt) -> TokenStream { - parse(input).unwrap_or_else(|e| e.to_compile_error()) +pub(crate) fn attribute(input: Stmt, mutability: Mutability) -> TokenStream { + parse(input, mutability).unwrap_or_else(|e| e.to_compile_error()) } -fn parse(mut stmt: Stmt) -> Result { +fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { match &mut stmt { Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::default().replace_expr_match(expr) + Context::new(mutability).replace_expr_match(expr) } - Stmt::Local(local) => Context::default().replace_local(local)?, - Stmt::Item(Item::Fn(ItemFn { block, .. })) => Dummy.visit_block_mut(block), - Stmt::Item(Item::Impl(item)) => replace_item_impl(item), - Stmt::Item(Item::Use(item)) => replace_item_use(item)?, + Stmt::Local(local) => Context::new(mutability).replace_local(local)?, + Stmt::Item(Item::Fn(ItemFn { block, .. })) => Dummy { mutability }.visit_block_mut(block), + Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), + Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, _ => {} } Ok(stmt.into_token_stream()) } -#[derive(Default)] struct Context { register: Option<(Ident, usize)>, replaced: bool, + mutability: Mutability, } impl Context { + fn new(mutability: Mutability) -> Self { + Self { register: None, replaced: false, mutability } + } + fn update(&mut self, ident: &Ident, len: usize) { if self.register.is_none() { self.register = Some((ident.clone(), len)); @@ -108,7 +112,7 @@ impl Context { if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) { self.update(&path.segments[0].ident, len); self.replaced = true; - replace_ident(&mut path.segments[0].ident); + replace_ident(&mut path.segments[0].ident, self.mutability); } } } @@ -128,13 +132,13 @@ fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool { } } -fn replace_item_impl(item: &mut ItemImpl) { +fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) { let PathSegment { ident, arguments } = match &mut *item.self_ty { Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(), _ => return, }; - replace_ident(ident); + replace_ident(ident, mutability); let mut lifetime_name = String::from(DEFAULT_LIFETIME_NAME); proj_lifetime_name(&mut lifetime_name, &item.generics.params); @@ -157,20 +161,29 @@ fn replace_item_impl(item: &mut ItemImpl) { } } -fn replace_item_use(item: &mut ItemUse) -> Result<()> { - let mut visitor = UseTreeVisitor { res: Ok(()) }; +fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> { + let mut visitor = UseTreeVisitor { res: Ok(()), mutability }; visitor.visit_item_use_mut(item); visitor.res } -fn replace_ident(ident: &mut Ident) { - *ident = proj_ident(ident); +fn replace_ident(ident: &mut Ident, mutability: Mutability) { + *ident = proj_ident(ident, mutability); } // ================================================================================================= // visitor -struct Dummy; +struct Dummy { + mutability: Mutability, +} + +impl Dummy { + /// Returns the attribute name. + fn name(&self) -> &str { + if self.mutability == Mutable { "project" } else { "project_ref" } + } +} impl VisitMut for Dummy { fn visit_stmt_mut(&mut self, node: &mut Stmt) { @@ -178,19 +191,19 @@ impl VisitMut for Dummy { let attr = match node { Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - expr.attrs.find_remove(NAME) + expr.attrs.find_remove(self.name()) } - Stmt::Local(local) => local.attrs.find_remove(NAME), + Stmt::Local(local) => local.attrs.find_remove(self.name()), _ => return, }; if let Some(attr) = attr { let res = syn::parse2::(attr.tokens).and_then(|_| match node { Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::default().replace_expr_match(expr); + Context::new(self.mutability).replace_expr_match(expr); Ok(()) } - Stmt::Local(local) => Context::default().replace_local(local), + Stmt::Local(local) => Context::new(self.mutability).replace_local(local), _ => unreachable!(), }); @@ -207,6 +220,7 @@ impl VisitMut for Dummy { struct UseTreeVisitor { res: Result<()>, + mutability: Mutability, } impl VisitMut for UseTreeVisitor { @@ -217,7 +231,7 @@ impl VisitMut for UseTreeVisitor { match node { // Desugar `use tree::` into `tree::__Projection`. - UseTree::Name(name) => replace_ident(&mut name.ident), + UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability), UseTree::Glob(glob) => { self.res = Err(error!(glob, "#[project] attribute may not be used on glob imports")); diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index eadcac1c..7b5713a6 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -7,9 +7,21 @@ use syn::{ pub(crate) const DEFAULT_LIFETIME_NAME: &str = "'_pin"; +pub(crate) use Mutability::{Immutable, Mutable}; + +#[derive(Clone, Copy, Eq, PartialEq)] +pub(crate) enum Mutability { + Mutable, + Immutable, +} + /// Creates the ident of projected type from the ident of the original type. -pub(crate) fn proj_ident(ident: &Ident) -> Ident { - format_ident!("__{}Projection", ident) +pub(crate) fn proj_ident(ident: &Ident, mutability: Mutability) -> Ident { + if mutability == Mutable { + format_ident!("__{}Projection", ident) + } else { + format_ident!("__{}ProjectionRef", ident) + } } /// Determines the lifetime names. Ensure it doesn't overlap with any existing lifetime names. diff --git a/src/lib.rs b/src/lib.rs index ba970475..820cc8d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,6 +60,9 @@ pub use pin_project_internal::pinned_drop; #[doc(hidden)] pub use pin_project_internal::project; +#[doc(hidden)] +pub use pin_project_internal::project_ref; + /// A trait used for custom implementations of [`Unpin`]. /// This trait is used in conjunction with the `UnsafeUnpin` /// argument to [`pin_project`] diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 9ddba8dc..108e212a 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -277,6 +277,12 @@ fn lifetime_project() { } impl Struct { + fn get_pin_ref<'a>(self: Pin<&'a Self>) -> Pin<&'a T> { + self.project_ref().pinned + } + fn get_pin_ref_elided(self: Pin<&Self>) -> Pin<&T> { + self.project_ref().pinned + } fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { self.project().pinned } @@ -286,6 +292,12 @@ fn lifetime_project() { } impl<'b, T, U> Struct2<'b, T, U> { + fn get_pin_ref<'a>(self: Pin<&'a Self>) -> Pin<&'a &'b mut T> { + self.project_ref().pinned + } + fn get_pin_ref_elided(self: Pin<&Self>) -> Pin<&&'b mut T> { + self.project_ref().pinned + } fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut &'b mut T> { self.project().pinned } @@ -295,6 +307,16 @@ fn lifetime_project() { } impl Enum { + fn get_pin_ref<'a>(self: Pin<&'a Self>) -> Pin<&'a T> { + match self.project_ref() { + __EnumProjectionRef::Variant { pinned, .. } => pinned, + } + } + fn get_pin_ref_elided(self: Pin<&Self>) -> Pin<&T> { + match self.project_ref() { + __EnumProjectionRef::Variant { pinned, .. } => pinned, + } + } fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { match self.project() { __EnumProjection::Variant { pinned, .. } => pinned, @@ -320,6 +342,9 @@ mod visibility { #[test] fn visibility() { let mut x = visibility::A { b: 0 }; - let x = Pin::new(&mut x).project(); - let _: &mut u8 = x.b; + let x = Pin::new(&mut x); + let y = x.as_ref().project_ref(); + let _: &u8 = y.b; + let y = x.project(); + let _: &mut u8 = y.b; } diff --git a/tests/project_ref.rs b/tests/project_ref.rs new file mode 100644 index 00000000..4b4904be --- /dev/null +++ b/tests/project_ref.rs @@ -0,0 +1,151 @@ +#![warn(unsafe_code)] +#![warn(rust_2018_idioms, single_use_lifetimes)] +#![allow(dead_code)] + +use core::pin::Pin; +use pin_project::{pin_project, project_ref}; + +#[project_ref] // Nightly does not need a dummy attribute to the function. +#[test] +fn project_stmt_expr() { + // struct + + #[pin_project] + struct Foo { + #[pin] + field1: T, + field2: U, + } + + let foo = Foo { field1: 1, field2: 2 }; + + #[project_ref] + let Foo { field1, field2 } = Pin::new(&foo).project_ref(); + + let x: Pin<&i32> = field1; + assert_eq!(*x, 1); + + let y: &i32 = field2; + assert_eq!(*y, 2); + + // tuple struct + + #[pin_project] + struct Bar(#[pin] T, U); + + let bar = Bar(1, 2); + + #[project_ref] + let Bar(x, y) = Pin::new(&bar).project_ref(); + + let x: Pin<&i32> = x; + assert_eq!(*x, 1); + + let y: &i32 = y; + assert_eq!(*y, 2); + + // enum + + #[pin_project] + enum Baz { + Variant1(#[pin] A, B), + Variant2 { + #[pin] + field1: C, + field2: D, + }, + None, + } + + let baz = Baz::Variant1(1, 2); + + let baz = Pin::new(&baz).project_ref(); + + #[project_ref] + match &baz { + Baz::Variant1(x, y) => { + let x: &Pin<&i32> = x; + assert_eq!(**x, 1); + + let y: &&i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &Pin<&i32> = field1; + let _y: &&i32 = field2; + } + Baz::None => {} + } + + #[project_ref] + let val = match &baz { + Baz::Variant1(_, _) => true, + Baz::Variant2 { .. } => false, + Baz::None => false, + }; + assert_eq!(val, true); +} + +#[test] +fn project_impl() { + #[pin_project] + struct HasGenerics { + #[pin] + field1: T, + field2: U, + } + + #[project_ref] + impl HasGenerics { + fn a(self) { + let Self { field1, field2 } = self; + + let _x: Pin<&T> = field1; + let _y: &U = field2; + } + } + + #[pin_project] + struct NoneGenerics { + #[pin] + field1: i32, + field2: u32, + } + + #[project_ref] + impl NoneGenerics {} + + #[pin_project] + struct HasLifetimes<'a, T, U> { + #[pin] + field1: &'a mut T, + field2: U, + } + + #[project_ref] + impl HasLifetimes<'_, T, U> {} + + #[pin_project] + struct HasOverlappingLifetimes<'_pin, T, U> { + #[pin] + field1: &'_pin mut T, + field2: U, + } + + #[allow(single_use_lifetimes)] + #[project_ref] + impl<'_pin, T, U> HasOverlappingLifetimes<'_pin, T, U> {} + + #[pin_project] + struct HasOverlappingLifetimes2 { + #[pin] + field1: T, + field2: U, + } + + #[allow(single_use_lifetimes)] + #[project_ref] + impl HasOverlappingLifetimes2 { + fn foo<'_pin>(&'_pin self) {} + } +}