diff --git a/.gitignore b/.gitignore index d1ed8bf0..273610f8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ +.idea +.vscode target/ -.vscode/ **/*.rs.bk Cargo.lock ___* diff --git a/src/lib.rs b/src/lib.rs index ca844d56..3b533c89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,6 +146,8 @@ extern crate proc_macro; #[macro_use] mod utils; +#[cfg(feature = "project_attr")] +mod pin_project; #[cfg(feature = "project_attr")] mod project; mod unsafe_project; @@ -425,3 +427,10 @@ pub fn project(args: TokenStream, input: TokenStream) -> TokenStream { assert!(args.is_empty()); TokenStream::from(project::attribute(input.into())) } + +// TODO: doc +#[cfg(feature = "project_attr")] +#[proc_macro_attribute] +pub fn pin_project(args: TokenStream, input: TokenStream) -> TokenStream { + TokenStream::from(pin_project::attribute(args.into(), input.into())) +} diff --git a/src/pin_project.rs b/src/pin_project.rs new file mode 100644 index 00000000..3e0ec621 --- /dev/null +++ b/src/pin_project.rs @@ -0,0 +1,148 @@ +use proc_macro2::{Group, Ident, TokenStream, TokenTree}; +use quote::ToTokens; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + visit_mut::VisitMut, + *, +}; + +use crate::project::Dummy; + +pub(super) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream { + syn::parse2(input) + .and_then(|mut item| { + syn::parse2(args).map(|args: Args| { + // TODO: Integrate into `replace_item_fn`?. + Dummy.visit_item_fn_mut(&mut item); + replace_item_fn(&args.0, &mut item); + item.into_token_stream() + }) + }) + .unwrap_or_else(|e| e.to_compile_error()) +} + +fn replace_item_fn(args: &[Ident], ItemFn { decl, block, .. }: &mut ItemFn) { + decl.inputs.iter_mut().for_each(|input| match input { + FnArg::Captured(ArgCaptured { + pat: Pat::Ident(pat @ PatIdent { subpat: None, .. }), + .. + }) if args.contains(&pat.ident) => { + let mut local = Local { + attrs: Vec::new(), + let_token: token::Let::default(), + pats: Punctuated::new(), + ty: None, + init: None, + semi_token: token::Semi::default(), + }; + let (local_pat, init) = if pat.ident == "self" { + ReplaceSelf.visit_block_mut(block); + let mut local_pat = pat.clone(); + prepend_underscores_to_self(&mut local_pat.ident); + (local_pat, syn::parse_quote!(self.project())) + } else { + let ident = &pat.ident; + (pat.clone(), syn::parse_quote!(#ident.project())) + }; + local.pats.push(Pat::Ident(local_pat)); + local.init = Some((token::Eq::default(), init)); + block.stmts.insert(0, Stmt::Local(local)); + + if pat.by_ref.is_none() { + pat.mutability = None; + } + pat.by_ref = None; + } + _ => {} + }) +} + +struct Args(Vec); + +impl Parse for Args { + fn parse(input: ParseStream<'_>) -> syn::Result { + let mut args = Vec::new(); + let mut first = true; + while !input.is_empty() { + if first { + first = false; + } else { + let _: Token![,] = input.parse()?; + if input.is_empty() { + break; + } + } + + let ident = if input.peek(Token![self]) { + let t: Token![self] = input.parse()?; + Ident::new("self", t.span) + } else { + input.parse()? + }; + if args.contains(&ident) { + // TODO: error + } else { + args.push(ident); + } + } + Ok(Self(args)) + } +} + +// https://github.com/dtolnay/no-panic/blob/master/src/lib.rs + +struct ReplaceSelf; + +impl VisitMut for ReplaceSelf { + fn visit_expr_path_mut(&mut self, i: &mut ExprPath) { + if i.qself.is_none() && i.path.is_ident("self") { + prepend_underscores_to_self(&mut i.path.segments[0].ident); + } + } + + fn visit_macro_mut(&mut self, i: &mut Macro) { + // We can't tell in general whether `self` inside a macro invocation + // refers to the self in the argument list or a different self + // introduced within the macro. Heuristic: if the macro input contains + // `fn`, then `self` is more likely to refer to something other than the + // outer function's self argument. + if !contains_fn(i.tts.clone()) { + i.tts = fold_token_stream(i.tts.clone()); + } + } + + fn visit_item_mut(&mut self, _i: &mut Item) { + // Do nothing, as `self` now means something else. + } +} + +fn contains_fn(tts: TokenStream) -> bool { + tts.into_iter().any(|tt| match tt { + TokenTree::Ident(ident) => ident == "fn", + TokenTree::Group(group) => contains_fn(group.stream()), + _ => false, + }) +} + +fn fold_token_stream(tts: TokenStream) -> TokenStream { + tts.into_iter() + .map(|tt| match tt { + TokenTree::Ident(mut ident) => { + prepend_underscores_to_self(&mut ident); + TokenTree::Ident(ident) + } + TokenTree::Group(group) => { + let content = fold_token_stream(group.stream()); + TokenTree::Group(Group::new(group.delimiter(), content)) + } + other => other, + }) + .collect() +} + +fn prepend_underscores_to_self(ident: &mut Ident) { + if ident == "self" { + *ident = Ident::new("__self", ident.span()); + } +} diff --git a/src/project.rs b/src/project.rs index 3f4b8126..7ed4a121 100644 --- a/src/project.rs +++ b/src/project.rs @@ -2,7 +2,6 @@ use proc_macro2::TokenStream; use quote::ToTokens; use syn::{ punctuated::Punctuated, - token::Or, visit_mut::{self, VisitMut}, *, }; @@ -83,7 +82,7 @@ impl Replace for ExprIf { } } -impl Replace for Punctuated { +impl Replace for Punctuated { fn replace(&mut self, register: &mut Register) { self.iter_mut().for_each(|pat| pat.replace(register)); } @@ -155,7 +154,7 @@ impl Register { // ================================================================================================= // visitor -struct Dummy; +pub(crate) struct Dummy; impl VisitMut for Dummy { fn visit_stmt_mut(&mut self, stmt: &mut Stmt) { diff --git a/tests/project.rs b/tests/project.rs index 99a5c80b..52b9cfb1 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -6,7 +6,7 @@ #![cfg(feature = "project_attr")] use core::pin::Pin; -use pin_project::{project, unsafe_project}; +use pin_project::{pin_project, project, unsafe_project}; #[project] // Nightly does not need a dummy attribute to the function. #[test] @@ -93,3 +93,103 @@ fn test_project_attr() { } } } + +#[test] +fn test_pin_project() { + // struct + + #[unsafe_project(Unpin)] + struct Foo { + #[pin] + field1: T, + field2: U, + } + + impl Foo { + #[pin_project(self)] + fn foo(self: Pin<&mut Self>) { + let x: Pin<&mut i32> = self.field1; + assert_eq!(*x, 1); + let y: &mut i32 = self.field2; + assert_eq!(*y, 2); + } + } + + let mut foo = Foo { field1: 1, field2: 2 }; + Pin::new(&mut foo).foo(); + + // tuple struct + + #[unsafe_project(Unpin)] + struct Bar(#[pin] T, U); + + impl Bar { + #[pin_project(self)] + fn bar(self: Pin<&mut Self>) { + let x: Pin<&mut i32> = self.0; + assert_eq!(*x, 1); + let y: &mut i32 = self.1; + assert_eq!(*y, 2); + } + } + + let mut bar = Bar(1, 2); + Pin::new(&mut bar).bar(); + + // enum + + #[unsafe_project(Unpin)] + enum Baz { + Variant1(#[pin] A, B), + Variant2 { + #[pin] + field1: C, + field2: D, + }, + None, + } + + impl Baz { + #[pin_project(self, bar)] + fn baz(mut self: Pin<&mut Self>, bar: Pin<&mut Bar>) { + #[project] + match &mut self { + Baz::Variant1(x, y) => { + let x: &mut Pin<&mut i32> = x; + assert_eq!(**x, 1); + let y: &mut &mut i32 = y; + assert_eq!(**y, 2); + } + Baz::Variant2 { field1, field2 } => { + let _x: &mut Pin<&mut i32> = field1; + let _y: &mut &mut i32 = field2; + } + Baz::None => {} + } + + #[project] + { + if let Baz::Variant1(x, y) = self { + let x: Pin<&mut i32> = x; + assert_eq!(*x, 1); + + let y: &mut i32 = y; + assert_eq!(*y, 2); + } else if let Option::Some(_) = Some(1) { + // Check that don't replace different types by mistake + } + } + + #[project] + let Bar(x, y) = bar; + let _: Pin<&mut i32> = x; + assert_eq!(*x, 1); + let _: &mut i32 = y; + assert_eq!(*y, 2); + } + } + + let mut baz = Baz::Variant1(1, 2); + + Pin::new(&mut baz).baz(Pin::new(&mut bar)); +}