Skip to content

Commit

Permalink
Auto merge of #120392 - compiler-errors:async-bound-modifier, r=<try>
Browse files Browse the repository at this point in the history
Introduce support for `async` bound modifier on `Fn*` traits

Adds `async` to the list of `TraitBoundModifiers`, which instructs AST lowering to map the trait to an async flavor of the trait. For now, this is only supported for `Fn*` to `AsyncFn*`, and I expect that this manual mapping via lang items will be replaced with a better system in the future.

The motivation for adding these bounds is to separate the users of async closures from the exact trait desugaring of their callable bounds. Instead of users needing to be concerned with the `AsyncFn` trait, they should be able to write `async Fn()` and it will desugar to whatever underlying trait we decide is best for the lowering of async closures.
  • Loading branch information
bors committed Jan 26, 2024
2 parents e7bbe8c + d2bb2cc commit cb0d84e
Show file tree
Hide file tree
Showing 29 changed files with 500 additions and 57 deletions.
29 changes: 26 additions & 3 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,16 @@ pub use crate::node_id::{NodeId, CRATE_NODE_ID, DUMMY_NODE_ID};
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
pub struct TraitBoundModifiers {
pub constness: BoundConstness,
pub asyncness: BoundAsyncness,
pub polarity: BoundPolarity,
}

impl TraitBoundModifiers {
pub const NONE: Self =
Self { constness: BoundConstness::Never, polarity: BoundPolarity::Positive };
pub const NONE: Self = Self {
constness: BoundConstness::Never,
asyncness: BoundAsyncness::Normal,
polarity: BoundPolarity::Positive,
};
}

/// The AST represents all type param bounds as types.
Expand Down Expand Up @@ -2562,6 +2566,25 @@ impl BoundConstness {
}
}

/// The asyncness of a trait bound.
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
#[derive(HashStable_Generic)]
pub enum BoundAsyncness {
// `Type: Trait`
Normal,
// `Type: async Trait`
Async(Span),
}

impl BoundAsyncness {
pub fn as_str(self) -> &'static str {
match self {
Self::Normal => "",
Self::Async(_) => "async",
}
}
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub enum FnRetTy {
/// Returns type is not specified.
Expand Down Expand Up @@ -3300,7 +3323,7 @@ mod size_asserts {
static_assert_size!(ForeignItem, 96);
static_assert_size!(ForeignItemKind, 24);
static_assert_size!(GenericArg, 24);
static_assert_size!(GenericBound, 72);
static_assert_size!(GenericBound, 88);
static_assert_size!(Generics, 40);
static_assert_size!(Impl, 136);
static_assert_size!(Item, 136);
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_ast_lowering/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ ast_lowering_argument = argument
ast_lowering_assoc_ty_parentheses =
parenthesized generic arguments cannot be used in associated type constraints
ast_lowering_async_bound_not_on_trait =
`async` bound modifier only allowed on trait, not `{$descr}`
ast_lowering_async_bound_only_for_fn_traits =
`async` bound modifier only allowed on `Fn`/`FnMut`/`FnOnce` traits
ast_lowering_async_coroutines_not_supported =
`async` coroutines are not yet supported
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_ast_lowering/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,18 @@ pub(crate) struct GenericParamDefaultInBinder {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(ast_lowering_async_bound_not_on_trait)]
pub(crate) struct AsyncBoundNotOnTrait {
#[primary_span]
pub span: Span,
pub descr: &'static str,
}

#[derive(Diagnostic)]
#[diag(ast_lowering_async_bound_only_for_fn_traits)]
pub(crate) struct AsyncBoundOnlyForFnTraits {
#[primary_span]
pub span: Span,
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
ParenthesizedGenericArgs::Err,
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
// Method calls can't have bound modifiers
None,
));
let receiver = self.lower_expr(receiver);
let args =
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
let itctx = ImplTraitContext::Universal;
let (generics, (trait_ref, lowered_ty)) =
self.lower_generics(ast_generics, *constness, id, &itctx, |this| {
let constness = match *constness {
Const::Yes(span) => BoundConstness::Maybe(span),
Const::No => BoundConstness::Never,
let modifiers = TraitBoundModifiers {
constness: match *constness {
Const::Yes(span) => BoundConstness::Maybe(span),
Const::No => BoundConstness::Never,
},
asyncness: BoundAsyncness::Normal,
// we don't use this in bound lowering
polarity: BoundPolarity::Positive,
};

let trait_ref = trait_ref.as_ref().map(|trait_ref| {
this.lower_trait_ref(
constness,
modifiers,
trait_ref,
&ImplTraitContext::Disallowed(ImplTraitPosition::Trait),
)
Expand Down
19 changes: 12 additions & 7 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct LoweringContext<'a, 'hir> {
allow_gen_future: Lrc<[Symbol]>,
allow_async_iterator: Lrc<[Symbol]>,
allow_for_await: Lrc<[Symbol]>,
allow_async_fn_traits: Lrc<[Symbol]>,

/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
Expand Down Expand Up @@ -176,6 +177,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
[sym::gen_future].into()
},
allow_for_await: [sym::async_iterator].into(),
allow_async_fn_traits: [sym::async_fn_traits].into(),
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
// interact with `gen`/`async gen` blocks
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
Expand Down Expand Up @@ -1311,7 +1313,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
span: t.span,
},
itctx,
ast::BoundConstness::Never,
TraitBoundModifiers::NONE,
);
let bounds = this.arena.alloc_from_iter([bound]);
let lifetime_bound = this.elided_dyn_bound(t.span);
Expand Down Expand Up @@ -1426,7 +1428,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
itctx,
// Still, don't pass along the constness here; we don't want to
// synthesize any host effect args, it'd only cause problems.
ast::BoundConstness::Never,
TraitBoundModifiers {
constness: BoundConstness::Never,
..*modifiers
},
))
}
BoundPolarity::Maybe(_) => None,
Expand Down Expand Up @@ -2019,7 +2024,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
) -> hir::GenericBound<'hir> {
match tpb {
GenericBound::Trait(p, modifiers) => hir::GenericBound::Trait(
self.lower_poly_trait_ref(p, itctx, modifiers.constness.into()),
self.lower_poly_trait_ref(p, itctx, *modifiers),
self.lower_trait_bound_modifiers(*modifiers),
),
GenericBound::Outlives(lifetime) => {
Expand Down Expand Up @@ -2192,7 +2197,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {

fn lower_trait_ref(
&mut self,
constness: ast::BoundConstness,
modifiers: ast::TraitBoundModifiers,
p: &TraitRef,
itctx: &ImplTraitContext,
) -> hir::TraitRef<'hir> {
Expand All @@ -2202,7 +2207,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&p.path,
ParamMode::Explicit,
itctx,
Some(constness),
Some(modifiers),
) {
hir::QPath::Resolved(None, path) => path,
qpath => panic!("lower_trait_ref: unexpected QPath `{qpath:?}`"),
Expand All @@ -2215,11 +2220,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&mut self,
p: &PolyTraitRef,
itctx: &ImplTraitContext,
constness: ast::BoundConstness,
modifiers: ast::TraitBoundModifiers,
) -> hir::PolyTraitRef<'hir> {
let bound_generic_params =
self.lower_lifetime_binder(p.trait_ref.ref_id, &p.bound_generic_params);
let trait_ref = self.lower_trait_ref(constness, &p.trait_ref, itctx);
let trait_ref = self.lower_trait_ref(modifiers, &p.trait_ref, itctx);
hir::PolyTraitRef { bound_generic_params, trait_ref, span: self.lower_span(p.span) }
}

Expand Down
100 changes: 90 additions & 10 deletions compiler/rustc_ast_lowering/src/path.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::ImplTraitPosition;

use super::errors::{GenericTypeWithParentheses, UseAngleBrackets};
use super::errors::{
AsyncBoundNotOnTrait, AsyncBoundOnlyForFnTraits, GenericTypeWithParentheses, UseAngleBrackets,
};
use super::ResolverAstLoweringExt;
use super::{GenericArgsCtor, LifetimeRes, ParenthesizedGenericArgs};
use super::{ImplTraitContext, LoweringContext, ParamMode};

use rustc_ast::{self as ast, *};
use rustc_data_structures::sync::Lrc;
use rustc_hir as hir;
use rustc_hir::def::{DefKind, PartialRes, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::GenericArg;
use rustc_middle::span_bug;
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::{BytePos, Span, DUMMY_SP};
use rustc_span::{BytePos, DesugaringKind, Span, Symbol, DUMMY_SP};

use smallvec::{smallvec, SmallVec};

Expand All @@ -24,8 +28,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
p: &Path,
param_mode: ParamMode,
itctx: &ImplTraitContext,
// constness of the impl/bound if this is a trait path
constness: Option<ast::BoundConstness>,
// modifiers of the impl/bound if this is a trait path
modifiers: Option<ast::TraitBoundModifiers>,
) -> hir::QPath<'hir> {
let qself_position = qself.as_ref().map(|q| q.position);
let qself = qself.as_ref().map(|q| self.lower_ty(&q.ty, itctx));
Expand All @@ -35,10 +39,36 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let base_res = partial_res.base_res();
let unresolved_segments = partial_res.unresolved_segments();

let mut res = self.lower_res(base_res);

// When we have an `async` kw on a bound, map the trait it resolves to.
let mut bound_modifier_allowed_features = None;
if let Some(TraitBoundModifiers { asyncness: BoundAsyncness::Async(_), .. }) = modifiers {
match res {
Res::Def(DefKind::Trait, def_id) => {
if let Some((async_def_id, features)) = self.map_trait_to_async_trait(def_id) {
res = Res::Def(DefKind::Trait, async_def_id);
bound_modifier_allowed_features = Some(features);
} else {
self.dcx().emit_err(AsyncBoundOnlyForFnTraits { span: p.span });
}
}
Res::Err => {
// No additional error.
}
_ => {
// This error isn't actually emitted AFAICT, but it's best to keep
// it around in case the resolver doesn't always check the defkind
// of an item or something.
self.dcx().emit_err(AsyncBoundNotOnTrait { span: p.span, descr: res.descr() });
}
}
}

let path_span_lo = p.span.shrink_to_lo();
let proj_start = p.segments.len() - unresolved_segments;
let path = self.arena.alloc(hir::Path {
res: self.lower_res(base_res),
res,
segments: self.arena.alloc_from_iter(p.segments[..proj_start].iter().enumerate().map(
|(i, segment)| {
let param_mode = match (qself_position, param_mode) {
Expand Down Expand Up @@ -77,7 +107,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
parenthesized_generic_args,
itctx,
// if this is the last segment, add constness to the trait path
if i == proj_start - 1 { constness } else { None },
if i == proj_start - 1 { modifiers.map(|m| m.constness) } else { None },
bound_modifier_allowed_features.clone(),
)
},
)),
Expand All @@ -88,6 +119,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
),
});

if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
path.span = self.mark_span_with_reason(
DesugaringKind::BoundModifier,
path.span,
Some(bound_modifier_allowed_features),
);
}

// Simple case, either no projections, or only fully-qualified.
// E.g., `std::mem::size_of` or `<I as Iterator>::Item`.
if unresolved_segments == 0 {
Expand Down Expand Up @@ -125,6 +164,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
ParenthesizedGenericArgs::Err,
itctx,
None,
None,
));
let qpath = hir::QPath::TypeRelative(ty, hir_segment);

Expand Down Expand Up @@ -166,6 +206,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
ParenthesizedGenericArgs::Err,
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
None,
)
})),
span: self.lower_span(p.span),
Expand All @@ -180,6 +221,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
parenthesized_generic_args: ParenthesizedGenericArgs,
itctx: &ImplTraitContext,
constness: Option<ast::BoundConstness>,
// Additional features ungated with a bound modifier like `async`.
// This is passed down to the implicit associated type binding in
// parenthesized bounds.
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
) -> hir::PathSegment<'hir> {
debug!("path_span: {:?}, lower_path_segment(segment: {:?})", path_span, segment);
let (mut generic_args, infer_args) = if let Some(generic_args) = segment.args.as_deref() {
Expand All @@ -188,9 +233,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.lower_angle_bracketed_parameter_data(data, param_mode, itctx)
}
GenericArgs::Parenthesized(data) => match parenthesized_generic_args {
ParenthesizedGenericArgs::ParenSugar => {
self.lower_parenthesized_parameter_data(data, itctx)
}
ParenthesizedGenericArgs::ParenSugar => self
.lower_parenthesized_parameter_data(
data,
itctx,
bound_modifier_allowed_features,
),
ParenthesizedGenericArgs::Err => {
// Suggest replacing parentheses with angle brackets `Trait(params...)` to `Trait<params...>`
let sub = if !data.inputs.is_empty() {
Expand Down Expand Up @@ -357,6 +405,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&mut self,
data: &ParenthesizedArgs,
itctx: &ImplTraitContext,
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
) -> (GenericArgsCtor<'hir>, bool) {
// Switch to `PassThrough` mode for anonymous lifetimes; this
// means that we permit things like `&Ref<T>`, where `Ref` has
Expand Down Expand Up @@ -392,7 +441,19 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
};
let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
let binding = self.assoc_ty_binding(sym::Output, output_ty.span, output_ty);

// If we have a bound like `async Fn() -> T`, make sure that we mark the
// `Output = T` associated type bound with the right feature gates.
let mut output_span = output_ty.span;
if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
output_span = self.mark_span_with_reason(
DesugaringKind::BoundModifier,
output_span,
Some(bound_modifier_allowed_features),
);
}
let binding = self.assoc_ty_binding(sym::Output, output_span, output_ty);

(
GenericArgsCtor {
args,
Expand Down Expand Up @@ -429,4 +490,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
kind,
}
}

/// When a bound is annotated with `async`, it signals to lowering that the trait
/// that the bound refers to should be mapped to the "async" flavor of the trait.
///
/// This only needs to be done until we unify `AsyncFn` and `Fn` traits into one
/// that is generic over `async`ness, if that's ever possible, or modify the
/// lowering of `async Fn()` bounds to desugar to another trait like `LendingFn`.
fn map_trait_to_async_trait(&self, def_id: DefId) -> Option<(DefId, Lrc<[Symbol]>)> {
let lang_items = self.tcx.lang_items();
if Some(def_id) == lang_items.fn_trait() {
Some((lang_items.async_fn_trait()?, self.allow_async_fn_traits.clone()))
} else if Some(def_id) == lang_items.fn_mut_trait() {
Some((lang_items.async_fn_mut_trait()?, self.allow_async_fn_traits.clone()))
} else if Some(def_id) == lang_items.fn_once_trait() {
Some((lang_items.async_fn_once_trait()?, self.allow_async_fn_traits.clone()))
} else {
None
}
}
}
1 change: 1 addition & 0 deletions compiler/rustc_expand/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ impl<'a> ExtCtxt<'a> {
} else {
ast::BoundConstness::Never
},
asyncness: ast::BoundAsyncness::Normal,
},
)
}
Expand Down
Loading

0 comments on commit cb0d84e

Please sign in to comment.