Skip to content

Commit

Permalink
Merge #47
Browse files Browse the repository at this point in the history
47: Make generated 'project' reference take an '&mut Pin<&mut Self>' r=taiki-e a=Aaron1011

Based on rust-lang/unsafe-code-guidelines#148 (comment)
by @CAD97

Currently, the generated 'project' method takes a 'Pin<&mut Self>',
consuming it. This makes it impossible to use the original Pin<&mut Self>
after calling project(), since the 'Pin<&mut Self>' has been moved into
the the 'Project' method.

This makes it impossible to implement useful pattern when working with
enums:

```rust

enum Foo {
    Variant1(#[pin] SomeFuture),
    Variant2(OtherType)
}

fn process(foo: Pin<&mut Foo>) {
    match foo.project() {
        __FooProjection(fut) => {
            fut.poll();
            let new_foo: Foo = ...;
            foo.set(new_foo);
        },
        _ => {}
    }
}
```

This pattern is common when implementing a Future combinator - an inner
future is polled, and then the containing enum is changed to a new
variant. However, as soon as 'project()' is called, it becoms imposible
to call 'set' on the original 'Pin<&mut Self>'.

To support this pattern, this commit changes the 'project' method to
take a '&mut Pin<&mut Self>'. The projection types work exactly as
before - however, creating it no longer requires consuming the original
'Pin<&mut Self>'

Unfortunately, current limitations of Rust prevent us from simply
modifying the signature of the 'project' method in the inherent impl
of the projection type. While using 'Pin<&mut Self>' as a receiver is
supported on stable rust, using '&mut Pin<&mut Self>' as a receiver
requires the unstable `#![feature(arbitrary_self_types)]`

For compatibility with stable Rust, we instead dynamically define a new
trait, '__{Type}ProjectionTrait', where {Type} is the name of the type
with the `#[pin_project]` attribute.

This trait looks like this:

```rust
trait __FooProjectionTrait {
    fn project(&'a mut self) -> __FooProjection<'a>;
}
```

It is then implemented for `Pin<&mut {Type}>`. This allows the `project`
method to be invoked on `&mut Pin<&mut {Type}>`, which is what we want.

If Generic Associated Types (rust-lang/rust#44265)
were implemented and stabilized, we could use a single trait for all pin
projections:

```rust
trait Projectable {
    type Projection<'a>;
    fn project(&'a mut self) -> Self::Projection<'a>;
}
```

However, Generic Associated Types are not even implemented on nightly
yet, so we need to generate a new trait per type for the foreseeable
future.

Co-authored-by: Aaron Hill <[email protected]>
  • Loading branch information
bors[bot] and Aaron1011 authored Aug 23, 2019
2 parents 62b4921 + c3be220 commit 44426ae
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 44 deletions.
16 changes: 8 additions & 8 deletions pin-project-internal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ use syn::parse::Nothing;
/// }
///
/// impl<T, U> Foo<T, U> {
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// let this = self.project();
/// let _: Pin<&mut T> = this.future; // Pinned reference to the field
/// let _: &mut U = this.field; // Normal reference to the field
Expand All @@ -115,7 +115,7 @@ use syn::parse::Nothing;
/// }
///
/// impl<T, U> Foo<T, U> {
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// let this = self.project();
/// let _: Pin<&mut T> = this.future; // Pinned reference to the field
/// let _: &mut U = this.field; // Normal reference to the field
Expand Down Expand Up @@ -162,7 +162,7 @@ use syn::parse::Nothing;
/// }
///
/// #[pinned_drop]
/// fn my_drop_fn<T: Debug, U: Debug>(foo: Pin<&mut Foo<T, U>>) {
/// fn my_drop_fn<T: Debug, U: Debug>(mut foo: Pin<&mut Foo<T, U>>) {
/// let foo = foo.project();
/// println!("Dropping pinned field: {:?}", foo.pinned_field);
/// println!("Dropping unpin field: {:?}", foo.unpin_field);
Expand Down Expand Up @@ -193,7 +193,7 @@ use syn::parse::Nothing;
/// }
///
/// impl<T, U> Foo<T, U> {
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// let this = self.project();
/// let _: Pin<&mut T> = this.future;
/// let _: &mut U = this.field;
Expand All @@ -211,7 +211,7 @@ use syn::parse::Nothing;
/// struct Foo<T, U>(#[pin] T, U);
///
/// impl<T, U> Foo<T, U> {
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// let this = self.project();
/// let _: Pin<&mut T> = this.0;
/// let _: &mut U = this.1;
Expand Down Expand Up @@ -250,7 +250,7 @@ use syn::parse::Nothing;
/// # #[cfg(feature = "project_attr")]
/// impl<A, B, C> Foo<A, B, C> {
/// #[project] // Nightly does not need a dummy attribute to the function.
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// #[project]
/// match self.project() {
/// Foo::Tuple(x, y) => {
Expand Down Expand Up @@ -347,7 +347,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream {
///
/// impl<T, U> Foo<T, U> {
/// #[project] // Nightly does not need a dummy attribute to the function.
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// #[project]
/// let Foo { future, field } = self.project();
///
Expand All @@ -372,7 +372,7 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream {
///
/// impl<A, B, C> Foo<A, B, C> {
/// #[project] // Nightly does not need a dummy attribute to the function.
/// fn baz(self: Pin<&mut Self>) {
/// fn baz(mut self: Pin<&mut Self>) {
/// #[project]
/// match self.project() {
/// Foo::Tuple(x, y) => {
Expand Down
13 changes: 7 additions & 6 deletions pin-project-internal/src/pin_project/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::utils::VecExt;

use super::{proj_generics, Context, PIN};

pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result<TokenStream> {
pub(super) fn parse(cx: &mut Context, mut item: ItemEnum) -> Result<TokenStream> {
if item.variants.is_empty() {
return Err(error!(item, "cannot be implemented for enums without variants"));
}
Expand All @@ -23,22 +23,23 @@ pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result<TokenStream>
return Err(error!(item.variants, "cannot be implemented for enums that have no field"));
}

let (proj_variants, proj_arms) = variants(&mut cx, &mut item)?;
let (proj_variants, proj_arms) = variants(cx, &mut item)?;

let impl_drop = cx.impl_drop(&item.generics);
let mut impl_drop = cx.impl_drop(&item.generics);
let Context { original, projected, lifetime, impl_unpin, .. } = cx;
let proj_generics = proj_generics(&item.generics, &lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;
let proj_trait = &cx.projected_trait;
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let mut proj_items = quote! {
enum #projected #proj_generics #where_clause { #(#proj_variants,)* }
};
let proj_method = quote! {
impl #impl_generics #original #ty_generics #where_clause {
fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #projected #proj_ty_generics {
impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #original #ty_generics> #where_clause {
fn project<#lifetime>(&#lifetime mut self) -> #projected #proj_ty_generics #where_clause {
unsafe {
match ::core::pin::Pin::get_unchecked_mut(self) {
match self.as_mut().get_unchecked_mut() {
#(#proj_arms,)*
}
}
Expand Down
56 changes: 44 additions & 12 deletions pin-project-internal/src/pin_project/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
token::Comma,
Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Index, Item, ItemStruct, Lifetime,
LifetimeDef, Meta, NestedMeta, Result, Type,
*,
};

use crate::utils::{crate_path, proj_ident};
use crate::utils::{crate_path, proj_ident, proj_trait_ident};

mod enums;
mod structs;
Expand Down Expand Up @@ -51,6 +50,10 @@ struct Context {
original: Ident,
/// Name of the projected type.
projected: Ident,
/// Name of the trait generated
/// to provide a 'project' method
projected_trait: Ident,
generics: Generics,

lifetime: Lifetime,
impl_unpin: ImplUnpin,
Expand All @@ -63,7 +66,16 @@ impl Context {
let projected = proj_ident(&original);
let lifetime = proj_lifetime(&generics.params);
let impl_unpin = ImplUnpin::new(generics, unsafe_unpin);
Ok(Self { original, projected, lifetime, impl_unpin, pinned_drop })
let projected_trait = proj_trait_ident(&original);
Ok(Self {
original,
projected,
projected_trait,
lifetime,
impl_unpin,
pinned_drop,
generics: generics.clone(),
})
}

fn impl_drop<'a>(&self, generics: &'a Generics) -> ImplDrop<'a> {
Expand All @@ -74,22 +86,42 @@ impl Context {
fn parse(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
match syn::parse2(input)? {
Item::Struct(item) => {
let cx = Context::new(args, item.ident.clone(), &item.generics)?;
let packed_check = ensure_not_packed(&item)?;
let mut res = structs::parse(cx, item)?;
res.extend(packed_check);
let mut cx = Context::new(args, item.ident.clone(), &item.generics)?;

let mut res = structs::parse(&mut cx, item.clone())?;
res.extend(ensure_not_packed(&item)?);
res.extend(make_proj_trait(&mut cx)?);
Ok(res)
}
Item::Enum(item) => {
let cx = Context::new(args, item.ident.clone(), &item.generics)?;
let mut cx = Context::new(args, item.ident.clone(), &item.generics)?;

// We don't need to check for '#[repr(packed)]',
// since it does not apply to enums
enums::parse(cx, item)
let mut res = enums::parse(&mut cx, item.clone())?;
res.extend(make_proj_trait(&mut cx)?);
Ok(res)
}
item => Err(error!(item, "may only be used on structs or enums")),
}
}

fn make_proj_trait(cx: &mut Context) -> Result<TokenStream> {
let proj_trait = &cx.projected_trait;
let lifetime = &cx.lifetime;
let proj_ident = &cx.projected;
let proj_generics = proj_generics(&cx.generics, &cx.lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;

let (orig_generics, _orig_ty_generics, orig_where_clause) = cx.generics.split_for_impl();

Ok(quote! {
trait #proj_trait #orig_generics {
fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #orig_where_clause;
}
})
}

fn ensure_not_packed(item: &ItemStruct) -> Result<TokenStream> {
for meta in item.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) {
if let Meta::List(l) = meta {
Expand Down Expand Up @@ -220,7 +252,7 @@ impl<'a> ImplDrop<'a> {
Self { generics, pinned_drop }
}

fn build(self, ident: &Ident) -> TokenStream {
fn build(&mut self, ident: &Ident) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();

if let Some(pinned_drop) = self.pinned_drop {
Expand Down Expand Up @@ -292,7 +324,7 @@ impl ImplUnpin {
}

/// Creates `Unpin` implementation.
fn build(self, ident: &Ident) -> TokenStream {
fn build(&mut self, ident: &Ident) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
quote! {
impl #impl_generics ::core::marker::Unpin for #ident #ty_generics #where_clause {}
Expand Down
15 changes: 8 additions & 7 deletions pin-project-internal/src/pin_project/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::utils::VecExt;

use super::{proj_generics, Context, PIN};

pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result<TokenStream> {
pub(super) fn parse(cx: &mut Context, mut item: ItemStruct) -> Result<TokenStream> {
let (proj_fields, proj_init) = match &mut item.fields {
Fields::Named(FieldsNamed { named: fields, .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. })
Expand All @@ -16,26 +16,27 @@ pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result<TokenStream
}
Fields::Unit => return Err(error!(item, "cannot be implemented for structs with units")),

Fields::Named(fields) => named(&mut cx, fields)?,
Fields::Unnamed(fields) => unnamed(&mut cx, fields)?,
Fields::Named(fields) => named(cx, fields)?,
Fields::Unnamed(fields) => unnamed(cx, fields)?,
};

let orig_ident = &cx.original;
let proj_ident = &cx.projected;
let lifetime = &cx.lifetime;
let impl_drop = cx.impl_drop(&item.generics);
let mut impl_drop = cx.impl_drop(&item.generics);
let proj_generics = proj_generics(&item.generics, &cx.lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;
let proj_trait = &cx.projected_trait;
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let mut proj_items = quote! {
struct #proj_ident #proj_generics #where_clause #proj_fields
};
let proj_method = quote! {
impl #impl_generics #orig_ident #ty_generics #where_clause {
fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #proj_ident #proj_ty_generics {
impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause {
fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause {
unsafe {
let this = ::core::pin::Pin::get_unchecked_mut(self);
let this = self.as_mut().get_unchecked_mut();
#proj_ident #proj_init
}
}
Expand Down
4 changes: 4 additions & 0 deletions pin-project-internal/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ pub(crate) fn proj_ident(ident: &Ident) -> Ident {
format_ident!("__{}Projection", ident)
}

pub(crate) fn proj_trait_ident(ident: &Ident) -> Ident {
format_ident!("__{}ProjectionTrait", ident)
}

pub(crate) trait VecExt {
fn find_remove(&mut self, ident: &str) -> Option<Attribute>;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! }
//!
//! impl<T, U> Foo<T, U> {
//! fn baz(self: Pin<&mut Self>) {
//! fn baz(mut self: Pin<&mut Self>) {
//! let this = self.project();
//! let _: Pin<&mut T> = this.future; // Pinned reference to the field
//! let _: &mut U = this.field; // Normal reference to the field
Expand Down
45 changes: 40 additions & 5 deletions tests/pin_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ fn test_pin_project() {

let mut foo = Foo { field1: 1, field2: 2 };

let foo = Pin::new(&mut foo).project();
let mut foo_orig = Pin::new(&mut foo);
let foo = foo_orig.project();

let x: Pin<&mut i32> = foo.field1;
assert_eq!(*x, 1);

let y: &mut i32 = foo.field2;
assert_eq!(*y, 2);

assert_eq!(foo_orig.as_ref().field1, 1);
assert_eq!(foo_orig.as_ref().field2, 2);

let mut foo = Foo { field1: 1, field2: 2 };

let foo = Pin::new(&mut foo).project();
let mut foo = Pin::new(&mut foo);
let foo = foo.project();

let __FooProjection { field1, field2 } = foo;
let _: Pin<&mut i32> = field1;
Expand All @@ -42,7 +47,8 @@ fn test_pin_project() {

let mut bar = Bar(1, 2);

let bar = Pin::new(&mut bar).project();
let mut bar = Pin::new(&mut bar);
let bar = bar.project();

let x: Pin<&mut i32> = bar.0;
assert_eq!(*x, 1);
Expand All @@ -53,6 +59,7 @@ fn test_pin_project() {
// enum

#[pin_project]
#[derive(Eq, PartialEq, Debug)]
enum Baz<A, B, C, D> {
Variant1(#[pin] A, B),
Variant2 {
Expand All @@ -65,7 +72,8 @@ fn test_pin_project() {

let mut baz = Baz::Variant1(1, 2);

let baz = Pin::new(&mut baz).project();
let mut baz_orig = Pin::new(&mut baz);
let baz = baz_orig.project();

match baz {
__BazProjection::Variant1(x, y) => {
Expand All @@ -82,9 +90,12 @@ fn test_pin_project() {
__BazProjection::None => {}
}

assert_eq!(Pin::into_ref(baz_orig).get_ref(), &Baz::Variant1(1, 2));

let mut baz = Baz::Variant2 { field1: 3, field2: 4 };

let mut baz = Pin::new(&mut baz).project();
let mut baz = Pin::new(&mut baz);
let mut baz = baz.project();

match &mut baz {
__BazProjection::Variant1(x, y) => {
Expand All @@ -110,6 +121,30 @@ fn test_pin_project() {
}
}

#[test]
fn enum_project_set() {
#[pin_project]
#[derive(Eq, PartialEq, Debug)]
enum Bar {
Variant1(#[pin] u8),
Variant2(bool),
}

let mut bar = Bar::Variant1(25);
let mut bar_orig = Pin::new(&mut bar);
let bar_proj = bar_orig.project();

match bar_proj {
__BarProjection::Variant1(val) => {
let new_bar = Bar::Variant2(val.as_ref().get_ref() == &25);
bar_orig.set(new_bar);
}
_ => unreachable!(),
}

assert_eq!(bar, Bar::Variant2(true));
}

#[test]
fn where_clause_and_associated_type_fields() {
// struct
Expand Down
2 changes: 1 addition & 1 deletion tests/pinned_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Foo<'a> {
}

#[pinned_drop]
fn do_drop(foo: Pin<&mut Foo<'_>>) {
fn do_drop(mut foo: Pin<&mut Foo<'_>>) {
**foo.project().was_dropped = true;
}

Expand Down
Loading

0 comments on commit 44426ae

Please sign in to comment.