diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 744134aa..8d5eeee9 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -19,9 +19,11 @@ pub(super) fn attribute(input: TokenStream) -> TokenStream { fn parse(input: TokenStream) -> Result { fn replace_stmt(stmt: &mut Stmt) { match stmt { - Stmt::Expr(expr) => expr.replace(&mut Register::default()), + Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { + expr.replace(&mut Register::default()) + } Stmt::Local(local) => local.replace(&mut Register::default()), - Stmt::Item(Item::Fn(item)) => Dummy.visit_item_fn_mut(item), + Stmt::Item(Item::Fn(ItemFn { block, .. })) => Dummy.visit_block_mut(block), _ => {} } } @@ -43,32 +45,9 @@ impl Replace for Local { } } -impl Replace for Expr { +impl Replace for ExprMatch { fn replace(&mut self, register: &mut Register) { - match self { - Expr::ForLoop(ExprForLoop { pat, .. }) => pat.replace(register), - Expr::Let(ExprLet { pats, .. }) => pats.replace(register), - - Expr::Match(ExprMatch { arms, .. }) => { - arms.iter_mut().for_each(|Arm { pats, .. }| pats.replace(register)) - } - - Expr::Block(ExprBlock { block, .. }) | Expr::Unsafe(ExprUnsafe { block, .. }) => { - if let Some(Stmt::Expr(expr)) = block.stmts.last_mut() { - expr.replace(register); - } - } - - Expr::While(ExprWhile { cond: expr, .. }) - | Expr::Type(ExprType { expr, .. }) - | Expr::Paren(ExprParen { expr, .. }) - | Expr::Reference(ExprReference { expr, .. }) => expr.replace(register), - - Expr::Path(ExprPath { qself: None, path, .. }) - | Expr::Struct(ExprStruct { path, .. }) => path.replace(register), - - _ => {} - } + self.arms.iter_mut().for_each(|Arm { pats, .. }| pats.replace(register)) } } @@ -83,7 +62,7 @@ impl Replace for Pat { match self { Pat::Ident(PatIdent { subpat: Some((_, pat)), .. }) | Pat::Ref(PatRef { pat, .. }) - | Pat::Box(PatBox { pat, .. }) => pat.replace(register), + | Pat::Box(PatBox { pat, .. }) => pat.replace(register), // | Pat::Type(PatBox { pat, .. }) // syn 1.0 Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) @@ -96,40 +75,28 @@ impl Replace for Pat { impl Replace for Path { fn replace(&mut self, register: &mut Register) { - fn is_none(args: &PathArguments) -> bool { - match args { - PathArguments::None => true, - _ => false, - } - } - - fn replace_ident(ident: &mut Ident) { - *ident = proj_ident(ident); - } - let len = match self.segments.len() { - // struct - 1 if is_none(&self.segments[0].arguments) => 1, - // enum - 2 if is_none(&self.segments[0].arguments) && is_none(&self.segments[1].arguments) => 2, + // 1: struct + // 2: enum + len @ 1 | len @ 2 => len, // other path _ => return, }; if register.0.is_none() || register.eq(&self.segments[0].ident, len) { register.update(&self.segments[0].ident, len); - replace_ident(&mut self.segments[0].ident); + self.segments[0].ident = proj_ident(&self.segments[0].ident) } } } #[derive(Default)] -struct Register(Option<(String, usize)>); +struct Register(Option<(Ident, usize)>); impl Register { fn update(&mut self, ident: &Ident, len: usize) { if self.0.is_none() { - self.0 = Some((ident.to_string(), len)); + self.0 = Some((ident.clone(), len)); } } @@ -148,97 +115,29 @@ struct Dummy; impl VisitMut for Dummy { fn visit_stmt_mut(&mut self, stmt: &mut Stmt) { - visit_mut::visit_stmt_mut(self, stmt); - visit_stmt_mut(stmt); - } - - // Stop at item bounds - fn visit_item_mut(&mut self, _item: &mut Item) {} -} - -fn visit_stmt_mut(stmt: &mut Stmt) { - fn parse_attr(attrs: &mut A) -> Result<()> { - if let Some(attr) = attrs.find_remove() { - let _: Nothing = syn::parse2(attr.tts)?; - attrs.replace(&mut Register::default()); + macro_rules! parse_attr { + ($this:expr) => {{ + $this.attrs.find_remove(NAME).map_or_else( + || Ok(()), + |attr| { + syn::parse2::(attr.tts) + .map(|_| $this.replace(&mut Register::default())) + }, + ) + }}; } - Ok(()) - } - - if let Err(e) = match stmt { - Stmt::Expr(expr) => parse_attr(expr), - Stmt::Local(local) => parse_attr(local), - _ => return, - } { - *stmt = Stmt::Expr(syn::parse2(e.to_compile_error()).unwrap()) - } -} - -trait AttrsMut { - fn attrs_mut) -> T>(&mut self, f: F) -> T; - - fn find_remove(&mut self) -> Option { - self.attrs_mut(|attrs| attrs.find_remove(NAME)) - } -} -impl AttrsMut for Local { - fn attrs_mut) -> T>(&mut self, f: F) -> T { - f(&mut self.attrs) - } -} + visit_mut::visit_stmt_mut(self, stmt); -macro_rules! attrs_impl { - ($($Expr:ident),*) => { - impl AttrsMut for Expr { - fn attrs_mut) -> T>(&mut self, f: F) -> T { - match self { - $(Expr::$Expr(expr) => f(&mut expr.attrs),)* - Expr::Verbatim(_) => f(&mut Vec::with_capacity(0)), - } - } + if let Err(e) = match stmt { + Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => parse_attr!(expr), + Stmt::Local(local) => parse_attr!(local), + _ => return, + } { + *stmt = Stmt::Expr(syn::parse2(e.to_compile_error()).unwrap()) } - }; -} + } -attrs_impl! { - Box, - InPlace, - Array, - Call, - MethodCall, - Tuple, - Binary, - Unary, - Lit, - Cast, - Type, - Let, - If, - While, - ForLoop, - Loop, - Match, - Closure, - Unsafe, - Block, - Assign, - AssignOp, - Field, - Index, - Range, - Path, - Reference, - Break, - Continue, - Return, - Macro, - Struct, - Repeat, - Paren, - Group, - Try, - Async, - TryBlock, - Yield + // Stop at item bounds + fn visit_item_mut(&mut self, _item: &mut Item) {} } diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index fa35ab21..40b1f46f 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -6,7 +6,7 @@ use syn::{ /// Makes the ident of projected type from the reference of the original ident. pub(crate) fn proj_ident(ident: &Ident) -> Ident { - Ident::new(&format!("__{}Projection", ident), Span::call_site()) + Ident::new(&format!("__{}Projection", ident), ident.span()) } pub(crate) trait VecExt { diff --git a/src/lib.rs b/src/lib.rs index 2d8c4110..28ad9714 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -156,6 +156,9 @@ /// *This attribute is available if pin-project is built with the /// `"project_attr"` feature (it is enabled by default).* /// +/// The attribute at the expression position is not stable, so you need to use +/// a dummy `#[project]` attribute for the function. +/// /// ## Examples /// /// The following two syntaxes are supported. @@ -424,6 +427,9 @@ pub use pin_project_internal::project; /// `pin_project` also supports enums, but to use it ergonomically, you need /// to use the [`project`] attribute. /// +/// The attribute at the expression position is not stable, so you need to use +/// a dummy `#[project]` attribute for the function. +/// /// ```rust /// # #[cfg(feature = "project_attr")] /// use pin_project::{project, pin_project}; diff --git a/tests/compile-test.rs b/tests/compile-test.rs index 8a68893d..c73a9296 100644 --- a/tests/compile-test.rs +++ b/tests/compile-test.rs @@ -13,7 +13,6 @@ fn run_mode(mode: &'static str) { "--edition=2018 \ -Z unstable-options \ --extern pin_project \ - --cfg procmacro2_semver_exempt \ -L {}", me.display() )); diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 4ca66646..ce1ae802 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -28,6 +28,14 @@ fn test_pin_project() { let y: &mut i32 = foo.field2; assert_eq!(*y, 2); + let mut foo = Foo { field1: 1, field2: 2 }; + + let foo = Pin::new(&mut foo).project(); + + let __FooProjection { field1, field2 } = foo; + let _: Pin<&mut i32> = field1; + let _: &mut i32 = field2; + // tuple struct #[pin_project] diff --git a/tests/project.rs b/tests/project.rs index a9c1f151..7db72d59 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -4,6 +4,7 @@ #![warn(rust_2018_idioms)] #![allow(dead_code)] #![cfg(feature = "project_attr")] +#![feature(proc_macro_hygiene, stmt_expr_attributes)] use core::pin::Pin; use pin_project::{pin_project, project}; @@ -80,3 +81,55 @@ fn test_project_attr() { Baz::None => {} } } + +#[test] +fn test_project_attr_nightly() { + // enum + + #[pin_project] + enum Baz { + Variant1(#[pin] A, B), + Variant2 { + #[pin] + field1: C, + field2: D, + }, + None, + } + + let mut baz = Baz::Variant1(1, 2); + + let mut baz = Pin::new(&mut baz).project(); + + #[project] + match &mut baz { + Baz::Variant1(x, y) => { + let x: &mut Pin<&mut i32> = x; + assert_eq!(**x, 1); + + let y: &mut &mut i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &mut Pin<&mut i32> = field1; + let _y: &mut &mut i32 = field2; + } + Baz::None => {} + } + + let () = #[project] + match &mut baz { + Baz::Variant1(x, y) => { + let x: &mut Pin<&mut i32> = x; + assert_eq!(**x, 1); + + let y: &mut &mut i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &mut Pin<&mut i32> = field1; + let _y: &mut &mut i32 = field2; + } + Baz::None => {} + }; +} diff --git a/tests/ui/project/type-mismatch.rs b/tests/ui/project/type-mismatch.rs new file mode 100644 index 00000000..8680cf5f --- /dev/null +++ b/tests/ui/project/type-mismatch.rs @@ -0,0 +1,81 @@ +// compile-fail + +#![deny(warnings, unsafe_code)] +#![feature(proc_macro_hygiene, stmt_expr_attributes)] + +use pin_project::{pin_project, project}; +use std::pin::Pin; + +#[project] +fn span() { + // enum + + #[pin_project] + enum Baz { + Variant1(#[pin] A, B), + Variant2 { + #[pin] + field1: C, + field2: D, + }, + None, + } + + let mut baz = Baz::Variant1(1, 2); + + let mut baz = Pin::new(&mut baz).project(); + + #[project] + match &mut baz { + Baz::Variant1(x, y) => { + let x: &mut Pin<&mut i32> = x; + assert_eq!(**x, 1); + + let y: &mut &mut i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &mut Pin<&mut i32> = field1; + let _y: &mut &mut i32 = field2; + } + None => {} //~ ERROR mismatched types + } +} + +// FIXME: `#[project]` for stmt/expr loses span +fn loses_span() { + // enum + + #[pin_project] + enum Baz { + Variant1(#[pin] A, B), + Variant2 { + #[pin] + field1: C, + field2: D, + }, + None, + } + + let mut baz = Baz::Variant1(1, 2); + + let mut baz = Pin::new(&mut baz).project(); + + #[project] //~ ERROR mismatched types + match &mut baz { + Baz::Variant1(x, y) => { + let x: &mut Pin<&mut i32> = x; + assert_eq!(**x, 1); + + let y: &mut &mut i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &mut Pin<&mut i32> = field1; + let _y: &mut &mut i32 = field2; + } + None => {} + } +} + +fn main() {} diff --git a/tests/ui/project/type-mismatch.stderr b/tests/ui/project/type-mismatch.stderr new file mode 100644 index 00000000..4558d91b --- /dev/null +++ b/tests/ui/project/type-mismatch.stderr @@ -0,0 +1,17 @@ +error[E0308]: mismatched types + --> $DIR/type-mismatch.rs:41:9 + | +41 | None => {} //~ ERROR mismatched types + | ^^^^ expected enum `span::__BazProjection`, found enum `std::option::Option` + | + = note: expected type `span::__BazProjection<'_, {integer}, {integer}, _, _>` + found type `std::option::Option<_>` + +error[E0308]: mismatched types + | + = note: expected type `loses_span::__BazProjection<'_, {integer}, {integer}, _, _>` + found type `std::option::Option<_>` + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0308`.