Skip to content

Commit

Permalink
Refactor where predicates, and reserve for attributes support
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-king committed Nov 11, 2024
1 parent 328b759 commit 2564d04
Show file tree
Hide file tree
Showing 46 changed files with 491 additions and 375 deletions.
37 changes: 31 additions & 6 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use rustc_macros::{Decodable, Encodable, HashStable_Generic};
pub use rustc_span::AttrId;
use rustc_span::source_map::{Spanned, respan};
use rustc_span::symbol::{Ident, Symbol, kw, sym};
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span, SyntaxContext};
use thin_vec::{ThinVec, thin_vec};

pub use crate::format::*;
Expand Down Expand Up @@ -428,7 +428,32 @@ impl Default for WhereClause {

/// A single predicate in a where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicate {
pub struct WherePredicate {
pub kind: WherePredicateKind,
pub id: NodeId,
pub span: Span,
}

impl WherePredicate {
pub fn with_kind(&self, kind: WherePredicateKind) -> WherePredicate {
self.map_kind(None, |_| kind)
}
pub fn map_kind(
&self,
ctxt: Option<SyntaxContext>,
f: impl FnOnce(&WherePredicateKind) -> WherePredicateKind,
) -> WherePredicate {
WherePredicate {
kind: f(&self.kind),
id: DUMMY_NODE_ID,
span: ctxt.map_or(self.span, |ctxt| self.span.with_ctxt(ctxt)),
}
}
}

/// Predicate kind in where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicateKind {
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
BoundPredicate(WhereBoundPredicate),
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
Expand All @@ -437,12 +462,12 @@ pub enum WherePredicate {
EqPredicate(WhereEqPredicate),
}

impl WherePredicate {
impl WherePredicateKind {
pub fn span(&self) -> Span {
match self {
WherePredicate::BoundPredicate(p) => p.span,
WherePredicate::RegionPredicate(p) => p.span,
WherePredicate::EqPredicate(p) => p.span,
WherePredicateKind::BoundPredicate(p) => p.span,
WherePredicateKind::RegionPredicate(p) => p.span,
WherePredicateKind::EqPredicate(p) => p.span,
}
}
}
Expand Down
16 changes: 13 additions & 3 deletions compiler/rustc_ast/src/ast_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::tokenstream::LazyAttrTokenStream;
use crate::{
Arm, AssocItem, AttrItem, AttrKind, AttrVec, Attribute, Block, Crate, Expr, ExprField,
FieldDef, ForeignItem, GenericParam, Item, NodeId, Param, Pat, PatField, Path, Stmt, StmtKind,
Ty, Variant, Visibility,
Ty, Variant, Visibility, WherePredicate,
};

/// A utility trait to reduce boilerplate.
Expand Down Expand Up @@ -79,6 +79,7 @@ impl_has_node_id!(
Stmt,
Ty,
Variant,
WherePredicate,
);

impl<T: AstDeref<Target: HasNodeId>> HasNodeId for T {
Expand Down Expand Up @@ -127,7 +128,16 @@ macro_rules! impl_has_tokens_none {
}

impl_has_tokens!(AssocItem, AttrItem, Block, Expr, ForeignItem, Item, Pat, Path, Ty, Visibility);
impl_has_tokens_none!(Arm, ExprField, FieldDef, GenericParam, Param, PatField, Variant);
impl_has_tokens_none!(
Arm,
ExprField,
FieldDef,
GenericParam,
Param,
PatField,
Variant,
WherePredicate
);

impl<T: AstDeref<Target: HasTokens>> HasTokens for T {
fn tokens(&self) -> Option<&LazyAttrTokenStream> {
Expand Down Expand Up @@ -290,7 +300,7 @@ impl_has_attrs!(
PatField,
Variant,
);
impl_has_attrs_none!(Attribute, AttrItem, Block, Pat, Path, Ty, Visibility);
impl_has_attrs_none!(Attribute, AttrItem, Block, Pat, Path, Ty, Visibility, WherePredicate);

impl<T: AstDeref<Target: HasAttrs>> HasAttrs for T {
const SUPPORTS_CUSTOM_INNER_ATTRS: bool = T::Target::SUPPORTS_CUSTOM_INNER_ATTRS;
Expand Down
34 changes: 26 additions & 8 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,15 @@ pub trait MutVisitor: Sized {
walk_where_clause(self, where_clause);
}

fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
walk_where_predicate(self, where_predicate);
fn flat_map_where_predicate(
&mut self,
where_predicate: WherePredicate,
) -> SmallVec<[WherePredicate; 1]> {
walk_flat_map_where_predicate(self, where_predicate)
}

fn visit_where_predicate_kind(&mut self, kind: &mut WherePredicateKind) {
walk_where_predicate_kind(self, kind)
}

fn visit_vis(&mut self, vis: &mut Visibility) {
Expand Down Expand Up @@ -987,26 +994,37 @@ fn walk_ty_alias_where_clauses<T: MutVisitor>(vis: &mut T, tawcs: &mut TyAliasWh

fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
let WhereClause { has_where_token: _, predicates, span } = wc;
visit_thin_vec(predicates, |predicate| vis.visit_where_predicate(predicate));
predicates.flat_map_in_place(|predicate| vis.flat_map_where_predicate(predicate));
vis.visit_span(span);
}

fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
match pred {
WherePredicate::BoundPredicate(bp) => {
pub fn walk_flat_map_where_predicate<T: MutVisitor>(
vis: &mut T,
mut pred: WherePredicate,
) -> SmallVec<[WherePredicate; 1]> {
let WherePredicate { ref mut kind, ref mut id, ref mut span } = pred;
vis.visit_id(id);
vis.visit_where_predicate_kind(kind);
vis.visit_span(span);
smallvec![pred]
}

pub fn walk_where_predicate_kind<T: MutVisitor>(vis: &mut T, kind: &mut WherePredicateKind) {
match kind {
WherePredicateKind::BoundPredicate(bp) => {
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
vis.visit_ty(bounded_ty);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::RegionPredicate(rp) => {
WherePredicateKind::RegionPredicate(rp) => {
let WhereRegionPredicate { span, lifetime, bounds } = rp;
vis.visit_lifetime(lifetime);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::EqPredicate(ep) => {
WherePredicateKind::EqPredicate(ep) => {
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
vis.visit_ty(lhs_ty);
vis.visit_ty(rhs_ty);
Expand Down
19 changes: 15 additions & 4 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
walk_where_predicate(self, p)
}
fn visit_where_predicate_kind(&mut self, k: &'ast WherePredicateKind) -> Self::Result {
walk_where_predicate_kind(self, k)
}
fn visit_fn(&mut self, fk: FnKind<'ast>, _: Span, _: NodeId) -> Self::Result {
walk_fn(self, fk)
}
Expand Down Expand Up @@ -786,8 +789,16 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
) -> V::Result {
match predicate {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let WherePredicate { kind, id: _, span: _ } = predicate;
visitor.visit_where_predicate_kind(kind)
}

pub fn walk_where_predicate_kind<'a, V: Visitor<'a>>(
visitor: &mut V,
kind: &'a WherePredicateKind,
) -> V::Result {
match kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bounded_ty,
bounds,
bound_generic_params,
Expand All @@ -797,11 +808,11 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
try_visit!(visitor.visit_ty(bounded_ty));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
try_visit!(visitor.visit_ty(lhs_ty));
try_visit!(visitor.visit_ty(rhs_ty));
}
Expand Down
13 changes: 4 additions & 9 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,10 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
}

fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
match predicate {
WherePredicate::BoundPredicate(pred) => {
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
self.with_parent(pred.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
})
}
_ => intravisit::walk_where_predicate(self, predicate),
}
self.insert(predicate.span, predicate.hir_id, Node::WherePredicate(predicate));
self.with_parent(predicate.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
});
}

fn visit_array_length(&mut self, len: &'hir ArrayLen<'hir>) {
Expand Down
59 changes: 32 additions & 27 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
// checks both param bounds and where clauses for `?Sized`.
for pred in &generics.where_clause.predicates {
let WherePredicate::BoundPredicate(bound_pred) = pred else {
let WherePredicateKind::BoundPredicate(ref bound_pred) = pred.kind else {
continue;
};
let compute_is_param = || {
Expand Down Expand Up @@ -1538,8 +1538,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
});
let span = self.lower_span(span);

match kind {
GenericParamKind::Const { .. } => None,
let kind = match kind {
GenericParamKind::Const { .. } => return None,
GenericParamKind::Type { .. } => {
let def_id = self.local_def_id(id).to_def_id();
let hir_id = self.next_id();
Expand All @@ -1554,38 +1554,38 @@ impl<'hir> LoweringContext<'_, 'hir> {
let ty_id = self.next_id();
let bounded_ty =
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bounded_ty: self.arena.alloc(bounded_ty),
bounds,
span,
bound_generic_params: &[],
origin,
}))
})
}
GenericParamKind::Lifetime => {
let ident = self.lower_ident(ident);
let lt_id = self.next_node_id();
let lifetime = self.new_named_lifetime(id, lt_id, ident);
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
lifetime,
span,
bounds,
in_where_clause: false,
}))
})
}
}
};
Some(hir::WherePredicate { hir_id: self.next_id(), kind: self.arena.alloc(kind), span })
}

fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
match pred {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let hir_id = self.lower_node_id(pred.id);
let kind = match &pred.kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bound_generic_params,
bounded_ty,
bounds,
span,
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
bound_generic_params: self
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
bounded_ty: self
Expand All @@ -1597,26 +1597,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
span: self.lower_span(*span),
origin: PredicateOrigin::WhereClause,
}),
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
})
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
WherePredicateKind::RegionPredicate(WhereRegionPredicate {
lifetime,
bounds,
span,
}) => hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
}),
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
lhs_ty: self
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
rhs_ty: self
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
span: self.lower_span(*span),
})
}
}
};
let kind = self.arena.alloc(kind);
let span = self.lower_span(pred.span);
hir::WherePredicate { hir_id, kind, span }
}
}
Loading

0 comments on commit 2564d04

Please sign in to comment.