Skip to content

Commit

Permalink
Add pin_project attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed Jun 17, 2019
1 parent c8b965b commit aa74464
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ project_attr = ["syn/visit-mut"]
[dependencies]
proc-macro2 = "0.4.13"
quote = "0.6.8"
syn = { version = "0.15.22", features = ["full"] }
syn = { version = "0.15.29", features = ["full"] }
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@

extern crate proc_macro;

#[cfg(feature = "project_attr")]
mod pin_project;
#[cfg(feature = "project_attr")]
mod project;
mod unsafe_project;
Expand Down Expand Up @@ -425,3 +427,9 @@ pub fn project(args: TokenStream, input: TokenStream) -> TokenStream {
assert!(args.is_empty());
TokenStream::from(project::attribute(input.into()))
}

#[cfg(feature = "project_attr")]
#[proc_macro_attribute]
pub fn pin_project(args: TokenStream, input: TokenStream) -> TokenStream {
TokenStream::from(pin_project::attribute(args.into(), input.into()))
}
151 changes: 151 additions & 0 deletions src/pin_project.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use std::convert::identity;

use proc_macro2::{Group, Ident, TokenStream, TokenTree};
use quote::ToTokens;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
visit_mut::VisitMut,
*,
};

use crate::project::dummy;

pub(super) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream {
syn::parse2(input)
.and_then(|mut item| {
syn::parse2(args).map(|args: Args| {
// TODO: Integrate into `replace_item_fn`?.
dummy(&mut item);
replace_item_fn(&args.0, &mut item);
item.into_token_stream()
})
})
.map_err(|err| err.to_compile_error())
.unwrap_or_else(identity)
}

fn replace_item_fn(args: &[Ident], ItemFn { decl, block, .. }: &mut ItemFn) {
decl.inputs.iter_mut().for_each(|input| match input {
FnArg::Captured(ArgCaptured {
pat: Pat::Ident(pat @ PatIdent { subpat: None, .. }),
..
}) if args.contains(&pat.ident) => {
let mut local = Local {
attrs: Vec::new(),
let_token: token::Let::default(),
pats: Punctuated::new(),
ty: None,
init: None,
semi_token: token::Semi::default(),
};
let (local_pat, init) = if pat.ident == "self" {
ReplaceSelf.visit_block_mut(block);
let mut local_pat = pat.clone();
prepend_underscores_to_self(&mut local_pat.ident);
(local_pat, syn::parse_quote!(self.project()))
} else {
let ident = &pat.ident;
(pat.clone(), syn::parse_quote!(#ident.project()))
};
local.pats.push(Pat::Ident(local_pat));
local.init = Some((token::Eq::default(), init));
block.stmts.insert(0, Stmt::Local(local));

if pat.by_ref.is_none() {
pat.mutability = None;
}
pat.by_ref = None;
}
_ => {}
})
}

struct Args(Vec<Ident>);

impl Parse for Args {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut args = Vec::new();
let mut first = true;
while !input.is_empty() {
if first {
first = false;
} else {
let _: Token![,] = input.parse()?;
if input.is_empty() {
break;
}
}

let ident = if input.peek(Token![self]) {
let t: Token![self] = input.parse()?;
Ident::new("self", t.span)
} else {
input.parse()?
};
if args.contains(&ident) {
// TODO: error
} else {
args.push(ident);
}
}
Ok(Self(args))
}
}

// https://github.com/dtolnay/no-panic/blob/master/src/lib.rs

struct ReplaceSelf;

impl VisitMut for ReplaceSelf {
fn visit_expr_path_mut(&mut self, i: &mut ExprPath) {
if i.qself.is_none() && i.path.is_ident("self") {
prepend_underscores_to_self(&mut i.path.segments[0].ident);
}
}

fn visit_macro_mut(&mut self, i: &mut Macro) {
// We can't tell in general whether `self` inside a macro invocation
// refers to the self in the argument list or a different self
// introduced within the macro. Heuristic: if the macro input contains
// `fn`, then `self` is more likely to refer to something other than the
// outer function's self argument.
if !contains_fn(i.tts.clone()) {
i.tts = fold_token_stream(i.tts.clone());
}
}

fn visit_item_mut(&mut self, _i: &mut Item) {
// Do nothing, as `self` now means something else.
}
}

fn contains_fn(tts: TokenStream) -> bool {
tts.into_iter().any(|tt| match tt {
TokenTree::Ident(ident) => ident == "fn",
TokenTree::Group(group) => contains_fn(group.stream()),
_ => false,
})
}

fn fold_token_stream(tts: TokenStream) -> TokenStream {
tts.into_iter()
.map(|tt| match tt {
TokenTree::Ident(mut ident) => {
prepend_underscores_to_self(&mut ident);
TokenTree::Ident(ident)
}
TokenTree::Group(group) => {
let content = fold_token_stream(group.stream());
TokenTree::Group(Group::new(group.delimiter(), content))
}
other => other,
})
.collect()
}

fn prepend_underscores_to_self(ident: &mut Ident) {
if ident == "self" {
*ident = Ident::new("__self", ident.span());
}
}
3 changes: 1 addition & 2 deletions src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{
punctuated::Punctuated,
token::Or,
visit_mut::{self, VisitMut},
*,
};
Expand Down Expand Up @@ -85,7 +84,7 @@ impl Replace for ExprIf {
}
}

impl Replace for Punctuated<Pat, Or> {
impl Replace for Punctuated<Pat, token::Or> {
fn replace(&mut self, register: &mut Register) {
self.iter_mut().for_each(|pat| pat.replace(register));
}
Expand Down
104 changes: 104 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,107 @@ fn test_project_attr() {
}
}
}

#[cfg(feature = "project_attr")]
use pin_project::pin_project;

#[cfg(feature = "project_attr")]
#[test]
fn test_pin_project() {
// struct

#[unsafe_project(Unpin)]
struct Foo<T, U> {
#[pin]
field1: T,
field2: U,
}

impl Foo<i32, i32> {
#[pin_project(self)]
fn foo(self: Pin<&mut Self>) {
let x: Pin<&mut i32> = self.field1;
assert_eq!(*x, 1);
let y: &mut i32 = self.field2;
assert_eq!(*y, 2);
}
}

let mut foo = Foo { field1: 1, field2: 2 };
Pin::new(&mut foo).foo();

// tuple struct

#[unsafe_project(Unpin)]
struct Bar<T, U>(#[pin] T, U);

impl Bar<i32, i32> {
#[pin_project(self)]
fn bar(self: Pin<&mut Self>) {
let x: Pin<&mut i32> = self.0;
assert_eq!(*x, 1);
let y: &mut i32 = self.1;
assert_eq!(*y, 2);
}
}

let mut bar = Bar(1, 2);
Pin::new(&mut bar).bar();

// enum

#[unsafe_project(Unpin)]
enum Baz<A, B, C, D> {
Variant1(#[pin] A, B),
Variant2 {
#[pin]
field1: C,
field2: D,
},
None,
}

impl Baz<i32, i32, i32, i32> {
#[pin_project(self, bar)]
fn baz(mut self: Pin<&mut Self>, bar: Pin<&mut Bar<i32, i32>>) {
#[project]
match &mut self {
Baz::Variant1(x, y) => {
let x: &mut Pin<&mut i32> = x;
assert_eq!(**x, 1);
let y: &mut &mut i32 = y;
assert_eq!(**y, 2);
}
Baz::Variant2 { field1, field2 } => {
let _x: &mut Pin<&mut i32> = field1;
let _y: &mut &mut i32 = field2;
}
Baz::None => {}
}

#[project]
{
if let Baz::Variant1(x, y) = self {
let x: Pin<&mut i32> = x;
assert_eq!(*x, 1);

let y: &mut i32 = y;
assert_eq!(*y, 2);
} else if let Option::Some(_) = Some(1) {
// Check that don't replace different types by mistake
}
}

#[project]
let Bar(x, y) = bar;
let _: Pin<&mut i32> = x;
assert_eq!(*x, 1);
let _: &mut i32 = y;
assert_eq!(*y, 2);
}
}

let mut baz = Baz::Variant1(1, 2);

Pin::new(&mut baz).baz(Pin::new(&mut bar));
}

0 comments on commit aa74464

Please sign in to comment.