1
Fork 0

Auto merge of #107507 - BoxyUwU:deferred_projection_equality, r=lcnr

Implement `deferred_projection_equality` for erica solver

Somewhat of a revival of #96912. When relating projections now emit an `AliasEq` obligation instead of attempting to determine equality of projections that may not be as normalized as possible (i.e. because of lazy norm, or just containing inference variables that prevent us from resolving an impl). Only do this when the new solver is enabled
This commit is contained in:
bors 2023-02-11 05:46:24 +00:00
commit 1623ab0246
46 changed files with 585 additions and 163 deletions

View file

@ -12,7 +12,7 @@ use crate::infer::canonical::{
Canonical, CanonicalQueryResponse, CanonicalVarValues, Certainty, OriginalQueryValues,
QueryOutlivesConstraint, QueryRegionConstraints, QueryResponse,
};
use crate::infer::nll_relate::{NormalizationStrategy, TypeRelating, TypeRelatingDelegate};
use crate::infer::nll_relate::{TypeRelating, TypeRelatingDelegate};
use crate::infer::region_constraints::{Constraint, RegionConstraintData};
use crate::infer::{InferCtxt, InferOk, InferResult, NllRegionVariableOrigin};
use crate::traits::query::{Fallible, NoSolution};
@ -717,10 +717,6 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for QueryTypeRelatingDelegate<'_, 'tcx> {
});
}
fn normalization() -> NormalizationStrategy {
NormalizationStrategy::Eager
}
fn forbid_inference_vars() -> bool {
true
}

View file

@ -38,8 +38,8 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{
self, FallibleTypeFolder, InferConst, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable,
TypeVisitable,
self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
TypeSuperFoldable, TypeVisitable,
};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::{Span, DUMMY_SP};
@ -74,7 +74,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: Ty<'tcx>,
) -> RelateResult<'tcx, Ty<'tcx>>
where
R: TypeRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
let a_is_expected = relation.a_is_expected();
@ -122,6 +122,15 @@ impl<'tcx> InferCtxt<'tcx> {
Err(TypeError::Sorts(ty::relate::expected_found(relation, a, b)))
}
(ty::Alias(AliasKind::Projection, _), _) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(a.into(), b.into());
Ok(b)
}
(_, ty::Alias(AliasKind::Projection, _)) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(b.into(), a.into());
Ok(a)
}
_ => ty::relate::super_relate_tys(relation, a, b),
}
}
@ -133,7 +142,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>>
where
R: ConstEquateRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
debug!("{}.consts({:?}, {:?})", relation.tag(), a, b);
if a == b {
@ -169,7 +178,7 @@ impl<'tcx> InferCtxt<'tcx> {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(b);
}
@ -177,7 +186,7 @@ impl<'tcx> InferCtxt<'tcx> {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(a);
}
@ -435,32 +444,21 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
Ok(Generalization { ty, needs_wf })
}
pub fn add_const_equate_obligation(
pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.obligations.extend(obligations.into_iter());
}
pub fn register_predicates(
&mut self,
a_is_expected: bool,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
) {
let predicate = if a_is_expected {
ty::PredicateKind::ConstEquate(a, b)
} else {
ty::PredicateKind::ConstEquate(b, a)
};
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(predicate),
));
self.obligations.extend(obligations.into_iter().map(|to_pred| {
Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred)
}))
}
pub fn mark_ambiguous(&mut self) {
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(ty::PredicateKind::Ambiguous),
));
self.register_predicates([ty::Binder::dummy(ty::PredicateKind::Ambiguous)]);
}
}
@ -779,11 +777,42 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
}
}
pub trait ConstEquateRelation<'tcx>: TypeRelation<'tcx> {
pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
/// Register obligations that must hold in order for this relation to hold
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>);
/// Register predicates that must hold in order for this relation to hold. Uses
/// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
/// be used if control over the obligaton causes is required.
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
);
/// Register an obligation that both constants must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>);
fn register_const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() {
ty::PredicateKind::AliasEq(a.into(), b.into())
} else {
ty::PredicateKind::ConstEquate(a, b)
})]);
}
/// Register an obligation that both types must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq(
a.into(),
b.into(),
))]);
}
}
fn int_unification_error<'tcx>(

View file

@ -1,4 +1,6 @@
use super::combine::{CombineFields, ConstEquateRelation, RelationDir};
use crate::traits::PredicateObligations;
use super::combine::{CombineFields, ObligationEmittingRelation, RelationDir};
use super::Subtype;
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
@ -198,8 +200,15 @@ impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> {
}
}
impl<'tcx> ConstEquateRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}

View file

@ -1,12 +1,11 @@
//! Greatest lower bound. See [`lattice`].
use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;
use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};
@ -136,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
&self.fields.trace.cause
}
fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}
fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(v, a)?;
@ -152,8 +147,15 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
}
}
impl<'tcx> ConstEquateRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}

View file

@ -17,11 +17,12 @@
//!
//! [lattices]: https://en.wikipedia.org/wiki/Lattice_(order)
use super::combine::ObligationEmittingRelation;
use super::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use super::InferCtxt;
use crate::traits::{ObligationCause, PredicateObligation};
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
use crate::traits::ObligationCause;
use rustc_middle::ty::relate::RelateResult;
use rustc_middle::ty::TyVar;
use rustc_middle::ty::{self, Ty};
@ -30,13 +31,11 @@ use rustc_middle::ty::{self, Ty};
///
/// GLB moves "down" the lattice (to smaller values); LUB moves
/// "up" the lattice (to bigger values).
pub trait LatticeDir<'f, 'tcx>: TypeRelation<'tcx> {
pub trait LatticeDir<'f, 'tcx>: ObligationEmittingRelation<'tcx> {
fn infcx(&self) -> &'f InferCtxt<'tcx>;
fn cause(&self) -> &ObligationCause<'tcx>;
fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>);
fn define_opaque_types(&self) -> bool;
// Relates the type `v` to `a` and `b` such that `v` represents
@ -113,7 +112,7 @@ where
| (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
if this.define_opaque_types() && def_id.is_local() =>
{
this.add_obligations(
this.register_obligations(
infcx
.handle_opaque_type(a, b, this.a_is_expected(), this.cause(), this.param_env())?
.obligations,

View file

@ -1,12 +1,11 @@
//! Least upper bound. See [`lattice`].
use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;
use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};
@ -127,12 +126,6 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
}
}
impl<'tcx> ConstEquateRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
}
}
impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx, 'tcx> {
fn infcx(&self) -> &'infcx InferCtxt<'tcx> {
self.fields.infcx
@ -142,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
&self.fields.trace.cause
}
fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}
fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(a, v)?;
@ -157,3 +146,16 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
self.fields.define_opaque_types
}
}
impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations)
}
}

View file

@ -4,6 +4,7 @@ pub use self::LateBoundRegionConversionTime::*;
pub use self::RegionVariableOrigin::*;
pub use self::SubregionOrigin::*;
pub use self::ValuePairs::*;
pub use combine::ObligationEmittingRelation;
use self::opaque_types::OpaqueTypeStorage;
pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};

View file

@ -21,11 +21,10 @@
//! thing we relate in chalk are basically domain goals and their
//! constituents)
use crate::infer::combine::ConstEquateRelation;
use crate::infer::InferCtxt;
use crate::infer::{ConstVarValue, ConstVariableValue};
use crate::infer::{TypeVariableOrigin, TypeVariableOriginKind};
use crate::traits::{Obligation, PredicateObligation};
use crate::traits::{Obligation, PredicateObligations};
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::error::TypeError;
@ -36,11 +35,7 @@ use rustc_span::Span;
use std::fmt::Debug;
use std::ops::ControlFlow;
#[derive(PartialEq)]
pub enum NormalizationStrategy {
Lazy,
Eager,
}
use super::combine::ObligationEmittingRelation;
pub struct TypeRelating<'me, 'tcx, D>
where
@ -92,7 +87,7 @@ pub trait TypeRelatingDelegate<'tcx> {
info: ty::VarianceDiagInfo<'tcx>,
);
fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>);
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>);
/// Creates a new universe index. Used when instantiating placeholders.
fn create_next_universe(&mut self) -> ty::UniverseIndex;
@ -125,9 +120,6 @@ pub trait TypeRelatingDelegate<'tcx> {
/// relation stating that `'?0: 'a`).
fn generalize_existential(&mut self, universe: ty::UniverseIndex) -> ty::Region<'tcx>;
/// Define the normalization strategy to use, eager or lazy.
fn normalization() -> NormalizationStrategy;
/// Enables some optimizations if we do not expect inference variables
/// in the RHS of the relation.
fn forbid_inference_vars() -> bool;
@ -265,38 +257,6 @@ where
self.delegate.push_outlives(sup, sub, info);
}
/// Relate a projection type and some value type lazily. This will always
/// succeed, but we push an additional `ProjectionEq` goal depending
/// on the value type:
/// - if the value type is any type `T` which is not a projection, we push
/// `ProjectionEq(projection = T)`.
/// - if the value type is another projection `other_projection`, we create
/// a new inference variable `?U` and push the two goals
/// `ProjectionEq(projection = ?U)`, `ProjectionEq(other_projection = ?U)`.
fn relate_projection_ty(
&mut self,
projection_ty: ty::AliasTy<'tcx>,
value_ty: Ty<'tcx>,
) -> Ty<'tcx> {
use rustc_span::DUMMY_SP;
match *value_ty.kind() {
ty::Alias(ty::Projection, other_projection_ty) => {
let var = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::MiscVariable,
span: DUMMY_SP,
});
// FIXME(lazy-normalization): This will always ICE, because the recursive
// call will end up in the _ arm below.
self.relate_projection_ty(projection_ty, var);
self.relate_projection_ty(other_projection_ty, var);
var
}
_ => bug!("should never be invoked with eager normalization"),
}
}
/// Relate a type inference variable with a value type. This works
/// by creating a "generalization" G of the value where all the
/// lifetimes are replaced with fresh inference values. This
@ -335,12 +295,6 @@ where
return Ok(value_ty);
}
ty::Alias(ty::Projection, projection_ty)
if D::normalization() == NormalizationStrategy::Lazy =>
{
return Ok(self.relate_projection_ty(projection_ty, self.infcx.tcx.mk_ty_var(vid)));
}
_ => (),
}
@ -627,18 +581,6 @@ where
self.relate_opaques(a, b)
}
(&ty::Alias(ty::Projection, projection_ty), _)
if D::normalization() == NormalizationStrategy::Lazy =>
{
Ok(self.relate_projection_ty(projection_ty, b))
}
(_, &ty::Alias(ty::Projection, projection_ty))
if D::normalization() == NormalizationStrategy::Lazy =>
{
Ok(self.relate_projection_ty(projection_ty, a))
}
_ => {
debug!(?a, ?b, ?self.ambient_variance);
@ -813,17 +755,26 @@ where
}
}
impl<'tcx, D> ConstEquateRelation<'tcx> for TypeRelating<'_, 'tcx, D>
impl<'tcx, D> ObligationEmittingRelation<'tcx> for TypeRelating<'_, 'tcx, D>
where
D: TypeRelatingDelegate<'tcx>,
{
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.delegate.register_obligations(vec![Obligation::new(
self.tcx(),
ObligationCause::dummy(),
self.param_env(),
ty::Binder::dummy(ty::PredicateKind::ConstEquate(a, b)),
)]);
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.delegate.register_obligations(
obligations
.into_iter()
.map(|to_pred| {
Obligation::new(self.tcx(), ObligationCause::dummy(), self.param_env(), to_pred)
})
.collect(),
);
}
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.delegate.register_obligations(obligations);
}
}

View file

@ -21,6 +21,7 @@ pub fn explicit_outlives_bounds<'tcx>(
.filter_map(move |kind| match kind {
ty::PredicateKind::Clause(ty::Clause::Projection(..))
| ty::PredicateKind::Clause(ty::Clause::Trait(..))
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::Coerce(..)
| ty::PredicateKind::Subtype(..)
| ty::PredicateKind::WellFormed(..)

View file

@ -21,16 +21,28 @@ impl<'tcx> InferCtxt<'tcx> {
recursion_depth: usize,
obligations: &mut Vec<PredicateObligation<'tcx>>,
) -> Ty<'tcx> {
let def_id = projection_ty.def_id;
let ty_var = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: self.tcx.def_span(def_id),
});
let projection =
ty::Binder::dummy(ty::ProjectionPredicate { projection_ty, term: ty_var.into() });
let obligation =
Obligation::with_depth(self.tcx, cause, recursion_depth, param_env, projection);
obligations.push(obligation);
ty_var
if self.tcx.trait_solver_next() {
// FIXME(-Ztrait-solver=next): Instead of branching here,
// completely change the normalization routine with the new solver.
//
// The new solver correctly handles projection equality so this hack
// is not necessary. if re-enabled it should emit `PredicateKind::AliasEq`
// not `PredicateKind::Clause(Clause::Projection(..))` as in the new solver
// `Projection` is used as `normalizes-to` which will fail for `<T as Trait>::Assoc eq ?0`.
return projection_ty.to_ty(self.tcx);
} else {
let def_id = projection_ty.def_id;
let ty_var = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: self.tcx.def_span(def_id),
});
let projection = ty::Binder::dummy(ty::PredicateKind::Clause(ty::Clause::Projection(
ty::ProjectionPredicate { projection_ty, term: ty_var.into() },
)));
let obligation =
Obligation::with_depth(self.tcx, cause, recursion_depth, param_env, projection);
obligations.push(obligation);
ty_var
}
}
}

View file

@ -1,8 +1,7 @@
use super::combine::{CombineFields, RelationDir};
use super::SubregionOrigin;
use super::{ObligationEmittingRelation, SubregionOrigin};
use crate::infer::combine::ConstEquateRelation;
use crate::traits::Obligation;
use crate::traits::{Obligation, PredicateObligations};
use rustc_middle::ty::relate::{Cause, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::visit::TypeVisitable;
use rustc_middle::ty::TyVar;
@ -228,8 +227,15 @@ impl<'tcx> TypeRelation<'tcx> for Sub<'_, '_, 'tcx> {
}
}
impl<'tcx> ConstEquateRelation<'tcx> for Sub<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}

View file

@ -294,6 +294,9 @@ impl<'tcx> Elaborator<'tcx> {
// Nothing to elaborate
}
ty::PredicateKind::Ambiguous => {}
ty::PredicateKind::AliasEq(..) => {
// No
}
}
}
}