Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more eagerly instantiate binders #119849

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ pub enum SelectionError<'tcx> {
OpaqueTypeAutoTraitLeakageUnknown(DefId),
}

// FIXME(@lcnr): The `Binder` here should be unnecessary. Just use `TraitRef` instead.
#[derive(Clone, Debug, TypeVisitable)]
pub struct SignatureMismatchData<'tcx> {
pub found_trait_ref: ty::PolyTraitRef<'tcx>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3409,6 +3409,8 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
self.dcx().try_steal_replace_and_emit_err(self.tcx.def_span(def_id), StashKey::Cycle, err)
}

// FIXME(@lcnr): This function could be changed to trait `TraitRef` directly
// instead of using a `Binder`.
fn report_signature_mismatch_error(
&self,
obligation: &PredicateObligation<'tcx>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let poly_trait_predicate = self.infcx.resolve_vars_if_possible(obligation.predicate);
let placeholder_trait_predicate =
self.infcx.enter_forall_and_leak_universe(poly_trait_predicate);
debug!(?placeholder_trait_predicate);

// The bounds returned by `item_bounds` may contain duplicates after
// normalization, so try to deduplicate when possible to avoid
Expand All @@ -184,8 +183,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
selcx.infcx.probe(|_| {
match selcx.match_normalize_trait_ref(
obligation,
bound.to_poly_trait_ref(),
placeholder_trait_predicate.trait_ref,
bound.to_poly_trait_ref(),
) {
Ok(None) => {
candidates.vec.push(ProjectionCandidate(idx));
Expand Down Expand Up @@ -881,8 +880,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
self.infcx.probe(|_| {
self.match_normalize_trait_ref(
obligation,
upcast_trait_ref,
placeholder_trait_predicate.trait_ref,
upcast_trait_ref,
)
.is_ok()
})
Expand Down
35 changes: 24 additions & 11 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use rustc_ast::Mutability;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir::lang_items::LangItem;
use rustc_infer::infer::BoundRegionConversionTime::HigherRankedType;
use rustc_infer::infer::HigherRankedType;
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::traits::{BuiltinImplSource, SignatureMismatchData};
use rustc_middle::ty::{
Expand Down Expand Up @@ -161,8 +161,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let placeholder_trait_predicate =
self.infcx.enter_forall_and_leak_universe(trait_predicate).trait_ref;
let placeholder_self_ty = placeholder_trait_predicate.self_ty();
let placeholder_trait_predicate = ty::Binder::dummy(placeholder_trait_predicate);

let candidate_predicate = self
.for_each_item_bound(
placeholder_self_ty,
Expand All @@ -182,6 +180,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
.expect("projection candidate is not a trait predicate")
.map_bound(|t| t.trait_ref);

let candidate = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
candidate,
);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By instantiating before we normalize, we allow more projections to be replaced by infer vars if their normalization is ambiguous, e.g. for<'a> <?0 as Trait<'a>>::Assoc stays as is while <?0 as Trait<'!a>>::Assoc gets replaced with an inference variable ?term and result in a nested Projection(<?0 as Trait<'!a>>::Assoc, ?term) goal.

let mut obligations = Vec::new();
let candidate = normalize_with_depth_to(
self,
Expand All @@ -195,7 +198,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
obligations.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, placeholder_trait_predicate, candidate)
.eq(DefineOpaqueTypes::No, placeholder_trait_predicate, candidate)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| Unimplemented)?,
);
Expand Down Expand Up @@ -499,7 +502,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let trait_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(trait_predicate.self_ty());
let obligation_trait_ref = ty::Binder::dummy(trait_predicate.trait_ref);
let ty::Dynamic(data, ..) = *self_ty.kind() else {
span_bug!(obligation.cause.span, "object candidate with non-object");
};
Expand All @@ -520,19 +522,24 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let unnormalized_upcast_trait_ref =
supertraits.nth(index).expect("supertraits iterator no longer has as many elements");

let upcast_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
unnormalized_upcast_trait_ref,
);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, can allow additional normalization

let upcast_trait_ref = normalize_with_depth_to(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
unnormalized_upcast_trait_ref,
upcast_trait_ref,
&mut nested,
);

nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation_trait_ref, upcast_trait_ref)
.eq(DefineOpaqueTypes::No, trait_predicate.trait_ref, upcast_trait_ref)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| Unimplemented)?,
);
Expand Down Expand Up @@ -1021,7 +1028,13 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
obligation: &PolyTraitObligation<'tcx>,
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let obligation_trait_ref = obligation.predicate.to_poly_trait_ref();
let obligation_trait_ref =
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
self_ty_trait_ref,
);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here

// Normalize the obligation and expected trait refs together, because why not
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
ensure_sufficient_stack(|| {
Expand All @@ -1037,15 +1050,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.map(|InferOk { mut obligations, .. }| {
obligations.extend(nested);
obligations
})
.map_err(|terr| {
SignatureMismatch(Box::new(SignatureMismatchData {
expected_trait_ref: obligation_trait_ref,
found_trait_ref: expected_trait_ref,
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
terr,
}))
})
Expand Down
29 changes: 20 additions & 9 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use rustc_errors::{Diag, EmissionGuarantee};
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::BoundRegionConversionTime;
use rustc_infer::infer::BoundRegionConversionTime::HigherRankedType;
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::traits::TraitObligation;
use rustc_middle::dep_graph::dep_kinds;
Expand All @@ -42,7 +43,7 @@ use rustc_middle::ty::_match::MatchAgainstFreshVars;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
use rustc_middle::ty::relate::TypeRelation;
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::{self, PolyProjectionPredicate, ToPolyTraitRef, ToPredicate};
use rustc_middle::ty::{self, PolyProjectionPredicate, ToPredicate};
use rustc_middle::ty::{Ty, TyCtxt, TypeFoldable, TypeVisitableExt};
use rustc_span::symbol::sym;
use rustc_span::Symbol;
Expand Down Expand Up @@ -1651,15 +1652,20 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
fn match_normalize_trait_ref(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
trait_bound: ty::PolyTraitRef<'tcx>,
placeholder_trait_ref: ty::TraitRef<'tcx>,
) -> Result<Option<ty::PolyTraitRef<'tcx>>, ()> {
trait_bound: ty::PolyTraitRef<'tcx>,
) -> Result<Option<ty::TraitRef<'tcx>>, ()> {
debug_assert!(!placeholder_trait_ref.has_escaping_bound_vars());
if placeholder_trait_ref.def_id != trait_bound.def_id() {
// Avoid unnecessary normalization
return Err(());
}

let trait_bound = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
trait_bound,
);
let Normalized { value: trait_bound, obligations: _ } = ensure_sufficient_stack(|| {
normalize_with_depth(
self,
Expand All @@ -1671,7 +1677,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
});
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, ty::Binder::dummy(placeholder_trait_ref), trait_bound)
.eq(DefineOpaqueTypes::No, placeholder_trait_ref, trait_bound)
.map(|InferOk { obligations: _, value: () }| {
// This method is called within a probe, so we can't have
// inference variables and placeholders escape.
Expand All @@ -1683,7 +1689,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
})
.map_err(|_| ())
}

fn where_clause_may_apply<'o>(
&mut self,
stack: &TraitObligationStack<'o, 'tcx>,
Expand Down Expand Up @@ -1733,7 +1738,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let is_match = self
.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation.predicate, infer_projection)
.eq(DefineOpaqueTypes::No, obligation.predicate, infer_projection)
.is_ok_and(|InferOk { obligations, value: () }| {
self.evaluate_predicates_recursively(
TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
Expand Down Expand Up @@ -2533,7 +2538,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(
.eq(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these changes are currently breaking, changed them to match the new solver for now (but both solvers should probably instantiate eagerly to allow instantiating hr trait objects when upcasting)

we need tests for all of this though 😁

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consequence of this is what? That we can't do:

trait Supertrait<'a, 'b> {}
trait Subtrait: for<'a, 'b> Supertrait<'a, 'b> {}
let upcast: &dyn for<'a> Supertrait<'a, 'a> = todo!() as &dyn Subtrait;

Because of higher-ranked eq?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, the consequence is that it fixes a previously unknown unsoundness 🤣

#![feature(trait_upcasting)]
trait Supertrait<'a, 'b> {
    fn cast(&self, x: &'a str) -> &'b str;
}

trait Subtrait<'a, 'b>: Supertrait<'a, 'b> {}

impl<'a> Supertrait<'a, 'a> for () {
    fn cast(&self, x: &'a str) -> &'a str {
        x
    }
}
impl<'a> Subtrait<'a, 'a> for () {}
fn unsound(x: &dyn for<'a> Subtrait<'a, 'a>) -> &dyn for<'a, 'b> Supertrait<'a, 'b> {
    x
}

fn transmute<'a, 'b>(x: &'a str) -> &'b str {
    unsound(&()).cast(x)
}

fn main() {
    let x;
    {
        let mut temp = String::from("hello there");
        x = transmute(temp.as_str());
    }
    println!("{x}");
}

DefineOpaqueTypes::No,
upcast_principal.map_bound(|trait_ref| {
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
Expand Down Expand Up @@ -2571,7 +2576,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, source_projection, target_projection)
.eq(DefineOpaqueTypes::No, source_projection, target_projection)
.map_err(|_| SelectionError::Unimplemented)?
.into_obligations(),
);
Expand Down Expand Up @@ -2615,9 +2620,15 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
obligation: &PolyTraitObligation<'tcx>,
poly_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, ()> {
let predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
poly_trait_ref,
);
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::No, obligation.predicate.to_poly_trait_ref(), poly_trait_ref)
.eq(DefineOpaqueTypes::No, predicate.trait_ref, trait_ref)
.map(|InferOk { obligations, .. }| obligations)
.map_err(|_| ())
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_trait_selection/src/traits/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ fn vtable_entries<'tcx>(
}

/// Find slot base for trait methods within vtable entries of another trait
// FIXME(@lcnr): This isn't a query, so why does it take a tuple as its argument.
pub(super) fn vtable_trait_first_method_offset<'tcx>(
tcx: TyCtxt<'tcx>,
key: (
Expand Down
1 change: 0 additions & 1 deletion src/tools/tidy/src/issues.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,6 @@
"ui/generic-associated-types/issue-92954.rs",
"ui/generic-associated-types/issue-93141.rs",
"ui/generic-associated-types/issue-93262.rs",
"ui/generic-associated-types/issue-93340.rs",
"ui/generic-associated-types/issue-93341.rs",
"ui/generic-associated-types/issue-93342.rs",
"ui/generic-associated-types/issue-93874.rs",
Expand Down
24 changes: 24 additions & 0 deletions tests/ui/associated-type-bounds/dedup-normalized-1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//@ check-pass

// We try to prove `T::Rigid: Into<?0>` and have 2 candidates from where-clauses:
//
// - `Into<String>`
// - `Into<<T::Rigid as Elaborate>::Assoc>`
//
// This causes ambiguity unless we normalize the alias in the second candidate
// to detect that they actually result in the same constraints.
trait Trait {
type Rigid: Elaborate<Assoc = String> + Into<String>;
}

trait Elaborate: Into<Self::Assoc> {
type Assoc;
}

fn impls<T: Into<U>, U>(_: T) {}

fn test<P: Trait>(rigid: P::Rigid) {
impls(rigid);
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// We try to prove `for<'b> T::Rigid: Bound<'b, ?0>` and have 2 candidates from where-clauses:
//
// - `for<'a> Bound<'a, String>`
// - `for<'a> Bound<'a, <T::Rigid as Elaborate>::Assoc>`
//
// This causes ambiguity unless we normalize the alias in the second candidate
// to detect that they actually result in the same constraints. We currently
// fail to detect that the constraints from these bounds are equal and error
// with ambiguity.
trait Bound<'a, U> {}

trait Trait {
type Rigid: Elaborate<Assoc = String> + for<'a> Bound<'a, String>;
}

trait Elaborate: for<'a> Bound<'a, Self::Assoc> {
type Assoc;
}

fn impls<T: for<'b> Bound<'b, U>, U>(_: T) {}

fn test<P: Trait>(rigid: P::Rigid) {
impls(rigid);
//~^ ERROR type annotations needed
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
error[E0283]: type annotations needed
--> $DIR/dedup-normalized-2-higher-ranked.rs:23:5
|
LL | impls(rigid);
| ^^^^^ cannot infer type of the type parameter `U` declared on the function `impls`
|
= note: cannot satisfy `for<'b> <P as Trait>::Rigid: Bound<'b, _>`
note: required by a bound in `impls`
--> $DIR/dedup-normalized-2-higher-ranked.rs:20:13
|
LL | fn impls<T: for<'b> Bound<'b, U>, U>(_: T) {}
| ^^^^^^^^^^^^^^^^^^^^ required by this bound in `impls`
help: consider specifying the generic arguments
|
LL | impls::<<P as Trait>::Rigid, U>(rigid);
| ++++++++++++++++++++++++++

error: aborting due to 1 previous error

For more information about this error, try `rustc --explain E0283`.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ LL | #![feature(return_type_notation)]
= note: see issue #109417 <https://github.com/rust-lang/rust/issues/109417> for more information
= note: `#[warn(incomplete_features)]` on by default

error[E0308]: mismatched types
error: implementation of `Send` is not general enough
--> $DIR/issue-110963-early.rs:14:5
|
LL | / spawn(async move {
Expand All @@ -16,17 +16,12 @@ LL | | if !hc.check().await {
LL | | log_health_check_failure().await;
LL | | }
LL | | });
| |______^ one type is more general than the other
| |______^ implementation of `Send` is not general enough
|
= note: expected trait `Send`
found trait `for<'a> Send`
note: the lifetime requirement is introduced here
--> $DIR/issue-110963-early.rs:34:17
|
LL | F: Future + Send + 'static,
| ^^^^
= note: `Send` would have to be implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'0>() }`, for any two lifetimes `'0` and `'1`...
= note: ...but `Send` is actually implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'2>() }`, for some specific lifetime `'2`

error[E0308]: mismatched types
error: implementation of `Send` is not general enough
--> $DIR/issue-110963-early.rs:14:5
|
LL | / spawn(async move {
Expand All @@ -35,17 +30,11 @@ LL | | if !hc.check().await {
LL | | log_health_check_failure().await;
LL | | }
LL | | });
| |______^ one type is more general than the other
|
= note: expected trait `Send`
found trait `for<'a> Send`
note: the lifetime requirement is introduced here
--> $DIR/issue-110963-early.rs:34:17
| |______^ implementation of `Send` is not general enough
|
LL | F: Future + Send + 'static,
| ^^^^
Comment on lines -42 to -46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we recover this easily?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looked into this a bit, we could extend NiceRegionError::report_trait_placeholder_mismatch to also point to the obligation cause, I do not want to do so in this PR however.

= note: `Send` would have to be implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'0>() }`, for any two lifetimes `'0` and `'1`...
= note: ...but `Send` is actually implemented for the type `impl Future<Output = bool> { <HC as HealthCheck>::check<'2>() }`, for some specific lifetime `'2`
= note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no`

error: aborting due to 2 previous errors; 1 warning emitted

For more information about this error, try `rustc --explain E0308`.
Loading
Loading