Skip to content

Commit

Permalink
Implment #[cfg] in where clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-king committed Oct 31, 2024
1 parent 75eff9a commit 416f1eb
Show file tree
Hide file tree
Showing 48 changed files with 1,298 additions and 262 deletions.
30 changes: 25 additions & 5 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,27 @@ impl Default for WhereClause {

/// A single predicate in a where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicate {
pub struct WherePredicate {
pub attrs: AttrVec,
pub kind: WherePredicateKind,
pub id: NodeId,
}

impl WherePredicate {
pub fn with_kind(&self, kind: WherePredicateKind) -> WherePredicate {
self.map_kind(|_| kind)
}
pub fn map_kind(
&self,
f: impl FnOnce(&WherePredicateKind) -> WherePredicateKind,
) -> WherePredicate {
WherePredicate { attrs: self.attrs.clone(), kind: f(&self.kind), id: DUMMY_NODE_ID }
}
}

/// Predicate kind in where-clause.
#[derive(Clone, Encodable, Decodable, Debug)]
pub enum WherePredicateKind {
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
BoundPredicate(WhereBoundPredicate),
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
Expand All @@ -431,12 +451,12 @@ pub enum WherePredicate {
EqPredicate(WhereEqPredicate),
}

impl WherePredicate {
impl WherePredicateKind {
pub fn span(&self) -> Span {
match self {
WherePredicate::BoundPredicate(p) => p.span,
WherePredicate::RegionPredicate(p) => p.span,
WherePredicate::EqPredicate(p) => p.span,
WherePredicateKind::BoundPredicate(p) => p.span,
WherePredicateKind::RegionPredicate(p) => p.span,
WherePredicateKind::EqPredicate(p) => p.span,
}
}
}
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_ast/src/ast_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::tokenstream::LazyAttrTokenStream;
use crate::{
Arm, AssocItem, AttrItem, AttrKind, AttrVec, Attribute, Block, Crate, Expr, ExprField,
FieldDef, ForeignItem, GenericParam, Item, NodeId, Param, Pat, PatField, Path, Stmt, StmtKind,
Ty, Variant, Visibility,
Ty, Variant, Visibility, WherePredicate,
};

/// A utility trait to reduce boilerplate.
Expand Down Expand Up @@ -79,6 +79,7 @@ impl_has_node_id!(
Stmt,
Ty,
Variant,
WherePredicate,
);

impl<T: AstDeref<Target: HasNodeId>> HasNodeId for T {
Expand Down Expand Up @@ -127,7 +128,16 @@ macro_rules! impl_has_tokens_none {
}

impl_has_tokens!(AssocItem, AttrItem, Block, Expr, ForeignItem, Item, Pat, Path, Ty, Visibility);
impl_has_tokens_none!(Arm, ExprField, FieldDef, GenericParam, Param, PatField, Variant);
impl_has_tokens_none!(
Arm,
ExprField,
FieldDef,
GenericParam,
Param,
PatField,
Variant,
WherePredicate
);

impl<T: AstDeref<Target: HasTokens>> HasTokens for T {
fn tokens(&self) -> Option<&LazyAttrTokenStream> {
Expand Down Expand Up @@ -289,6 +299,7 @@ impl_has_attrs!(
Param,
PatField,
Variant,
WherePredicate,
);
impl_has_attrs_none!(Attribute, AttrItem, Block, Pat, Path, Ty, Visibility);

Expand Down
26 changes: 18 additions & 8 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,11 @@ pub trait MutVisitor: Sized {
walk_where_clause(self, where_clause);
}

fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
walk_where_predicate(self, where_predicate);
fn filter_map_where_predicate(
&mut self,
where_predicate: WherePredicate,
) -> Option<WherePredicate> {
walk_filter_map_where_predicate(self, where_predicate)
}

fn visit_vis(&mut self, vis: &mut Visibility) {
Expand Down Expand Up @@ -987,32 +990,39 @@ fn walk_ty_alias_where_clauses<T: MutVisitor>(vis: &mut T, tawcs: &mut TyAliasWh

fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
let WhereClause { has_where_token: _, predicates, span } = wc;
visit_thin_vec(predicates, |predicate| vis.visit_where_predicate(predicate));
predicates.flat_map_in_place(|predicate| vis.filter_map_where_predicate(predicate));
vis.visit_span(span);
}

fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
match pred {
WherePredicate::BoundPredicate(bp) => {
pub fn walk_filter_map_where_predicate<T: MutVisitor>(
vis: &mut T,
mut pred: WherePredicate,
) -> Option<WherePredicate> {
let WherePredicate { ref mut attrs, ref mut kind, ref mut id } = pred;
vis.visit_id(id);
visit_attrs(vis, attrs);
match kind {
WherePredicateKind::BoundPredicate(bp) => {
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
vis.visit_ty(bounded_ty);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::RegionPredicate(rp) => {
WherePredicateKind::RegionPredicate(rp) => {
let WhereRegionPredicate { span, lifetime, bounds } = rp;
vis.visit_lifetime(lifetime);
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
vis.visit_span(span);
}
WherePredicate::EqPredicate(ep) => {
WherePredicateKind::EqPredicate(ep) => {
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
vis.visit_ty(lhs_ty);
vis.visit_ty(rhs_ty);
vis.visit_span(span);
}
}
Some(pred)
}

fn walk_variant_data<T: MutVisitor>(vis: &mut T, vdata: &mut VariantData) {
Expand Down
10 changes: 6 additions & 4 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,10 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
) -> V::Result {
match predicate {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let WherePredicate { attrs, kind, id: _ } = predicate;
walk_list!(visitor, visit_attribute, attrs);
match kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bounded_ty,
bounds,
bound_generic_params,
Expand All @@ -797,11 +799,11 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
try_visit!(visitor.visit_ty(bounded_ty));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
try_visit!(visitor.visit_ty(lhs_ty));
try_visit!(visitor.visit_ty(rhs_ty));
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,9 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
}

fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
match predicate {
WherePredicate::BoundPredicate(pred) => {
self.insert(predicate.span(), predicate.hir_id, Node::WherePredicate(predicate));
match predicate.kind {
WherePredicateKind::BoundPredicate(pred) => {
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
self.with_parent(pred.hir_id, |this| {
intravisit::walk_where_predicate(this, predicate)
Expand Down
57 changes: 32 additions & 25 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
// checks both param bounds and where clauses for `?Sized`.
for pred in &generics.where_clause.predicates {
let WherePredicate::BoundPredicate(bound_pred) = pred else {
let WherePredicateKind::BoundPredicate(ref bound_pred) = pred.kind else {
continue;
};
let compute_is_param = || {
Expand Down Expand Up @@ -1538,8 +1538,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
});
let span = self.lower_span(span);

match kind {
GenericParamKind::Const { .. } => None,
let kind = match kind {
GenericParamKind::Const { .. } => return None,
GenericParamKind::Type { .. } => {
let def_id = self.local_def_id(id).to_def_id();
let hir_id = self.next_id();
Expand All @@ -1554,37 +1554,40 @@ impl<'hir> LoweringContext<'_, 'hir> {
let ty_id = self.next_id();
let bounded_ty =
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
bounded_ty: self.arena.alloc(bounded_ty),
bounds,
span,
bound_generic_params: &[],
origin,
}))
})
}
GenericParamKind::Lifetime => {
let ident = self.lower_ident(ident);
let lt_id = self.next_node_id();
let lifetime = self.new_named_lifetime(id, lt_id, ident);
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
lifetime,
span,
bounds,
in_where_clause: false,
}))
})
}
}
};
Some(hir::WherePredicate { hir_id: self.next_id(), kind: self.arena.alloc(kind) })
}

fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
match pred {
WherePredicate::BoundPredicate(WhereBoundPredicate {
let hir_id = self.lower_node_id(pred.id);
self.lower_attrs(hir_id, &pred.attrs);
let kind = match &pred.kind {
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
bound_generic_params,
bounded_ty,
bounds,
span,
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
hir_id: self.next_id(),
bound_generic_params: self
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
Expand All @@ -1597,26 +1600,30 @@ impl<'hir> LoweringContext<'_, 'hir> {
span: self.lower_span(*span),
origin: PredicateOrigin::WhereClause,
}),
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
})
}
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
WherePredicateKind::RegionPredicate(WhereRegionPredicate {
lifetime,
bounds,
span,
}) => hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
span: self.lower_span(*span),
lifetime: self.lower_lifetime(lifetime),
bounds: self.lower_param_bounds(
bounds,
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
),
in_where_clause: true,
}),
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
lhs_ty: self
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
rhs_ty: self
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
span: self.lower_span(*span),
})
}
}
};
let kind = self.arena.alloc(kind);
hir::WherePredicate { hir_id, kind }
}
}
20 changes: 10 additions & 10 deletions compiler/rustc_ast_passes/src/ast_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1202,14 +1202,14 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
validate_generic_param_order(self.dcx(), &generics.params, generics.span);

for predicate in &generics.where_clause.predicates {
if let WherePredicate::EqPredicate(predicate) = predicate {
if let WherePredicateKind::EqPredicate(ref predicate) = predicate.kind {
deny_equality_constraints(self, predicate, generics);
}
}
walk_list!(self, visit_generic_param, &generics.params);
for predicate in &generics.where_clause.predicates {
match predicate {
WherePredicate::BoundPredicate(bound_pred) => {
match predicate.kind {
WherePredicateKind::BoundPredicate(ref bound_pred) => {
// This is slightly complicated. Our representation for poly-trait-refs contains a single
// binder and thus we only allow a single level of quantification. However,
// the syntax of Rust permits quantification in two places in where clauses,
Expand Down Expand Up @@ -1593,18 +1593,18 @@ fn deny_equality_constraints(
let mut preds = generics.where_clause.predicates.iter().peekable();
// Find the predicate that shouldn't have been in the where bound list.
while let Some(pred) = preds.next() {
if let WherePredicate::EqPredicate(pred) = pred
if let WherePredicateKind::EqPredicate(ref pred) = pred.kind
&& pred.span == predicate.span
{
if let Some(next) = preds.peek() {
// This is the first predicate, remove the trailing comma as well.
span = span.with_hi(next.span().lo());
span = span.with_hi(next.kind.span().lo());
} else if let Some(prev) = prev {
// Remove the previous comma as well.
span = span.with_lo(prev.hi());
}
}
prev = Some(pred.span());
prev = Some(pred.kind.span());
}
span
};
Expand All @@ -1621,8 +1621,8 @@ fn deny_equality_constraints(
if let TyKind::Path(None, full_path) = &predicate.lhs_ty.kind {
// Given `A: Foo, Foo::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
for bounds in generics.params.iter().map(|p| &p.bounds).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p) => Some(&p.bounds),
generics.where_clause.predicates.iter().filter_map(|pred| match pred.kind {
WherePredicateKind::BoundPredicate(ref p) => Some(&p.bounds),
_ => None,
}),
) {
Expand All @@ -1645,8 +1645,8 @@ fn deny_equality_constraints(
// Given `A: Foo, A::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
if let [potential_param, potential_assoc] = &full_path.segments[..] {
for (ident, bounds) in generics.params.iter().map(|p| (p.ident, &p.bounds)).chain(
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
WherePredicate::BoundPredicate(p)
generics.where_clause.predicates.iter().filter_map(|pred| match pred.kind {
WherePredicateKind::BoundPredicate(ref p)
if let ast::TyKind::Path(None, path) = &p.bounded_ty.kind
&& let [segment] = &path.segments[..] =>
{
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,9 @@ impl<'a> Visitor<'a> for PostExpansionVisitor<'a> {

fn visit_generics(&mut self, g: &'a ast::Generics) {
for predicate in &g.where_clause.predicates {
match predicate {
ast::WherePredicate::BoundPredicate(bound_pred) => {
visit::walk_list!(self, visit_attribute, &predicate.attrs);
match predicate.kind {
ast::WherePredicateKind::BoundPredicate(ref bound_pred) => {
// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
self.check_late_bound_lifetime_defs(&bound_pred.bound_generic_params);
}
Expand Down
Loading

0 comments on commit 416f1eb

Please sign in to comment.