Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce support for async bound modifier on Fn* traits #120392

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,16 @@ pub use crate::node_id::{NodeId, CRATE_NODE_ID, DUMMY_NODE_ID};
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
pub struct TraitBoundModifiers {
pub constness: BoundConstness,
pub asyncness: BoundAsyncness,
pub polarity: BoundPolarity,
}

impl TraitBoundModifiers {
pub const NONE: Self =
Self { constness: BoundConstness::Never, polarity: BoundPolarity::Positive };
pub const NONE: Self = Self {
constness: BoundConstness::Never,
asyncness: BoundAsyncness::Normal,
polarity: BoundPolarity::Positive,
};
}

/// The AST represents all type param bounds as types.
Expand Down Expand Up @@ -2562,6 +2566,25 @@ impl BoundConstness {
}
}

/// The asyncness of a trait bound.
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
#[derive(HashStable_Generic)]
pub enum BoundAsyncness {
/// `Type: Trait`
Normal,
/// `Type: async Trait`
Async(Span),
}

impl BoundAsyncness {
pub fn as_str(self) -> &'static str {
match self {
Self::Normal => "",
Self::Async(_) => "async",
}
}
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub enum FnRetTy {
/// Returns type is not specified.
Expand Down Expand Up @@ -3300,7 +3323,7 @@ mod size_asserts {
static_assert_size!(ForeignItem, 96);
static_assert_size!(ForeignItemKind, 24);
static_assert_size!(GenericArg, 24);
static_assert_size!(GenericBound, 72);
static_assert_size!(GenericBound, 88);
static_assert_size!(Generics, 40);
static_assert_size!(Impl, 136);
static_assert_size!(Item, 136);
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_ast_lowering/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ ast_lowering_argument = argument
ast_lowering_assoc_ty_parentheses =
parenthesized generic arguments cannot be used in associated type constraints

ast_lowering_async_bound_not_on_trait =
`async` bound modifier only allowed on trait, not `{$descr}`

ast_lowering_async_bound_only_for_fn_traits =
`async` bound modifier only allowed on `Fn`/`FnMut`/`FnOnce` traits

ast_lowering_async_coroutines_not_supported =
`async` coroutines are not yet supported

Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_ast_lowering/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,18 @@ pub(crate) struct GenericParamDefaultInBinder {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(ast_lowering_async_bound_not_on_trait)]
pub(crate) struct AsyncBoundNotOnTrait {
#[primary_span]
pub span: Span,
pub descr: &'static str,
}

#[derive(Diagnostic)]
#[diag(ast_lowering_async_bound_only_for_fn_traits)]
pub(crate) struct AsyncBoundOnlyForFnTraits {
#[primary_span]
pub span: Span,
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
ParenthesizedGenericArgs::Err,
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
// Method calls can't have bound modifiers
None,
));
let receiver = self.lower_expr(receiver);
let args =
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
let itctx = ImplTraitContext::Universal;
let (generics, (trait_ref, lowered_ty)) =
self.lower_generics(ast_generics, *constness, id, &itctx, |this| {
let constness = match *constness {
Const::Yes(span) => BoundConstness::Maybe(span),
Const::No => BoundConstness::Never,
let modifiers = TraitBoundModifiers {
constness: match *constness {
Const::Yes(span) => BoundConstness::Maybe(span),
Const::No => BoundConstness::Never,
},
asyncness: BoundAsyncness::Normal,
// we don't use this in bound lowering
polarity: BoundPolarity::Positive,
};

let trait_ref = trait_ref.as_ref().map(|trait_ref| {
this.lower_trait_ref(
constness,
modifiers,
trait_ref,
&ImplTraitContext::Disallowed(ImplTraitPosition::Trait),
)
Expand Down
19 changes: 12 additions & 7 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct LoweringContext<'a, 'hir> {
allow_gen_future: Lrc<[Symbol]>,
allow_async_iterator: Lrc<[Symbol]>,
allow_for_await: Lrc<[Symbol]>,
allow_async_fn_traits: Lrc<[Symbol]>,

/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
Expand Down Expand Up @@ -176,6 +177,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
[sym::gen_future].into()
},
allow_for_await: [sym::async_iterator].into(),
allow_async_fn_traits: [sym::async_fn_traits].into(),
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
// interact with `gen`/`async gen` blocks
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
Expand Down Expand Up @@ -1311,7 +1313,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
span: t.span,
},
itctx,
ast::BoundConstness::Never,
TraitBoundModifiers::NONE,
);
let bounds = this.arena.alloc_from_iter([bound]);
let lifetime_bound = this.elided_dyn_bound(t.span);
Expand Down Expand Up @@ -1426,7 +1428,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
itctx,
// Still, don't pass along the constness here; we don't want to
// synthesize any host effect args, it'd only cause problems.
ast::BoundConstness::Never,
TraitBoundModifiers {
constness: BoundConstness::Never,
..*modifiers
},
))
}
BoundPolarity::Maybe(_) => None,
Expand Down Expand Up @@ -2019,7 +2024,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
) -> hir::GenericBound<'hir> {
match tpb {
GenericBound::Trait(p, modifiers) => hir::GenericBound::Trait(
self.lower_poly_trait_ref(p, itctx, modifiers.constness.into()),
self.lower_poly_trait_ref(p, itctx, *modifiers),
self.lower_trait_bound_modifiers(*modifiers),
),
GenericBound::Outlives(lifetime) => {
Expand Down Expand Up @@ -2192,7 +2197,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {

fn lower_trait_ref(
&mut self,
constness: ast::BoundConstness,
modifiers: ast::TraitBoundModifiers,
p: &TraitRef,
itctx: &ImplTraitContext,
) -> hir::TraitRef<'hir> {
Expand All @@ -2202,7 +2207,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&p.path,
ParamMode::Explicit,
itctx,
Some(constness),
Some(modifiers),
) {
hir::QPath::Resolved(None, path) => path,
qpath => panic!("lower_trait_ref: unexpected QPath `{qpath:?}`"),
Expand All @@ -2215,11 +2220,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&mut self,
p: &PolyTraitRef,
itctx: &ImplTraitContext,
constness: ast::BoundConstness,
modifiers: ast::TraitBoundModifiers,
) -> hir::PolyTraitRef<'hir> {
let bound_generic_params =
self.lower_lifetime_binder(p.trait_ref.ref_id, &p.bound_generic_params);
let trait_ref = self.lower_trait_ref(constness, &p.trait_ref, itctx);
let trait_ref = self.lower_trait_ref(modifiers, &p.trait_ref, itctx);
hir::PolyTraitRef { bound_generic_params, trait_ref, span: self.lower_span(p.span) }
}

Expand Down
100 changes: 90 additions & 10 deletions compiler/rustc_ast_lowering/src/path.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::ImplTraitPosition;

use super::errors::{GenericTypeWithParentheses, UseAngleBrackets};
use super::errors::{
AsyncBoundNotOnTrait, AsyncBoundOnlyForFnTraits, GenericTypeWithParentheses, UseAngleBrackets,
};
use super::ResolverAstLoweringExt;
use super::{GenericArgsCtor, LifetimeRes, ParenthesizedGenericArgs};
use super::{ImplTraitContext, LoweringContext, ParamMode};

use rustc_ast::{self as ast, *};
use rustc_data_structures::sync::Lrc;
use rustc_hir as hir;
use rustc_hir::def::{DefKind, PartialRes, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::GenericArg;
use rustc_middle::span_bug;
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::{BytePos, Span, DUMMY_SP};
use rustc_span::{BytePos, DesugaringKind, Span, Symbol, DUMMY_SP};

use smallvec::{smallvec, SmallVec};

Expand All @@ -24,8 +28,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
p: &Path,
param_mode: ParamMode,
itctx: &ImplTraitContext,
// constness of the impl/bound if this is a trait path
constness: Option<ast::BoundConstness>,
// modifiers of the impl/bound if this is a trait path
modifiers: Option<ast::TraitBoundModifiers>,
) -> hir::QPath<'hir> {
let qself_position = qself.as_ref().map(|q| q.position);
let qself = qself.as_ref().map(|q| self.lower_ty(&q.ty, itctx));
Expand All @@ -35,10 +39,36 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let base_res = partial_res.base_res();
let unresolved_segments = partial_res.unresolved_segments();

let mut res = self.lower_res(base_res);

// When we have an `async` kw on a bound, map the trait it resolves to.
let mut bound_modifier_allowed_features = None;
if let Some(TraitBoundModifiers { asyncness: BoundAsyncness::Async(_), .. }) = modifiers {
match res {
Res::Def(DefKind::Trait, def_id) => {
if let Some((async_def_id, features)) = self.map_trait_to_async_trait(def_id) {
res = Res::Def(DefKind::Trait, async_def_id);
bound_modifier_allowed_features = Some(features);
} else {
self.dcx().emit_err(AsyncBoundOnlyForFnTraits { span: p.span });
}
}
Res::Err => {
// No additional error.
}
_ => {
// This error isn't actually emitted AFAICT, but it's best to keep
// it around in case the resolver doesn't always check the defkind
// of an item or something.
self.dcx().emit_err(AsyncBoundNotOnTrait { span: p.span, descr: res.descr() });
}
}
}

let path_span_lo = p.span.shrink_to_lo();
let proj_start = p.segments.len() - unresolved_segments;
let path = self.arena.alloc(hir::Path {
res: self.lower_res(base_res),
res,
segments: self.arena.alloc_from_iter(p.segments[..proj_start].iter().enumerate().map(
|(i, segment)| {
let param_mode = match (qself_position, param_mode) {
Expand Down Expand Up @@ -77,7 +107,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
parenthesized_generic_args,
itctx,
// if this is the last segment, add constness to the trait path
if i == proj_start - 1 { constness } else { None },
if i == proj_start - 1 { modifiers.map(|m| m.constness) } else { None },
bound_modifier_allowed_features.clone(),
)
},
)),
Expand All @@ -88,6 +119,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
),
});

if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
path.span = self.mark_span_with_reason(
DesugaringKind::BoundModifier,
path.span,
Some(bound_modifier_allowed_features),
);
}

// Simple case, either no projections, or only fully-qualified.
// E.g., `std::mem::size_of` or `<I as Iterator>::Item`.
if unresolved_segments == 0 {
Expand Down Expand Up @@ -125,6 +164,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
ParenthesizedGenericArgs::Err,
itctx,
None,
None,
));
let qpath = hir::QPath::TypeRelative(ty, hir_segment);

Expand Down Expand Up @@ -166,6 +206,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
ParenthesizedGenericArgs::Err,
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
None,
)
})),
span: self.lower_span(p.span),
Expand All @@ -180,6 +221,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
parenthesized_generic_args: ParenthesizedGenericArgs,
itctx: &ImplTraitContext,
constness: Option<ast::BoundConstness>,
// Additional features ungated with a bound modifier like `async`.
// This is passed down to the implicit associated type binding in
// parenthesized bounds.
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
) -> hir::PathSegment<'hir> {
debug!("path_span: {:?}, lower_path_segment(segment: {:?})", path_span, segment);
let (mut generic_args, infer_args) = if let Some(generic_args) = segment.args.as_deref() {
Expand All @@ -188,9 +233,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.lower_angle_bracketed_parameter_data(data, param_mode, itctx)
}
GenericArgs::Parenthesized(data) => match parenthesized_generic_args {
ParenthesizedGenericArgs::ParenSugar => {
self.lower_parenthesized_parameter_data(data, itctx)
}
ParenthesizedGenericArgs::ParenSugar => self
.lower_parenthesized_parameter_data(
data,
itctx,
bound_modifier_allowed_features,
),
ParenthesizedGenericArgs::Err => {
// Suggest replacing parentheses with angle brackets `Trait(params...)` to `Trait<params...>`
let sub = if !data.inputs.is_empty() {
Expand Down Expand Up @@ -357,6 +405,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
&mut self,
data: &ParenthesizedArgs,
itctx: &ImplTraitContext,
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
) -> (GenericArgsCtor<'hir>, bool) {
// Switch to `PassThrough` mode for anonymous lifetimes; this
// means that we permit things like `&Ref<T>`, where `Ref` has
Expand Down Expand Up @@ -392,7 +441,19 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
};
let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
let binding = self.assoc_ty_binding(sym::Output, output_ty.span, output_ty);

// If we have a bound like `async Fn() -> T`, make sure that we mark the
// `Output = T` associated type bound with the right feature gates.
let mut output_span = output_ty.span;
if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
output_span = self.mark_span_with_reason(
DesugaringKind::BoundModifier,
output_span,
Some(bound_modifier_allowed_features),
);
}
let binding = self.assoc_ty_binding(sym::Output, output_span, output_ty);

(
GenericArgsCtor {
args,
Expand Down Expand Up @@ -429,4 +490,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
kind,
}
}

/// When a bound is annotated with `async`, it signals to lowering that the trait
/// that the bound refers to should be mapped to the "async" flavor of the trait.
///
/// This only needs to be done until we unify `AsyncFn` and `Fn` traits into one
/// that is generic over `async`ness, if that's ever possible, or modify the
/// lowering of `async Fn()` bounds to desugar to another trait like `LendingFn`.
fn map_trait_to_async_trait(&self, def_id: DefId) -> Option<(DefId, Lrc<[Symbol]>)> {
let lang_items = self.tcx.lang_items();
if Some(def_id) == lang_items.fn_trait() {
Some((lang_items.async_fn_trait()?, self.allow_async_fn_traits.clone()))
} else if Some(def_id) == lang_items.fn_mut_trait() {
Some((lang_items.async_fn_mut_trait()?, self.allow_async_fn_traits.clone()))
} else if Some(def_id) == lang_items.fn_once_trait() {
Some((lang_items.async_fn_once_trait()?, self.allow_async_fn_traits.clone()))
} else {
None
}
}
}
Loading
Loading