Instantiate binders when checking supertrait upcasting

This commit is contained in:
Michael Goulet 2024-09-26 23:20:59 -04:00
parent d4ee408afc
commit 4fb097a5de
4 changed files with 132 additions and 48 deletions

View file

@ -16,6 +16,7 @@ use rustc_hir::LangItem;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::BoundRegionConversionTime::{self, HigherRankedType};
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::infer::at::ToTrace;
use rustc_infer::infer::relate::TypeRelation;
use rustc_infer::traits::TraitObligation;
use rustc_middle::bug;
@ -44,7 +45,7 @@ use super::{
TraitQueryMode, const_evaluatable, project, util, wf,
};
use crate::error_reporting::InferCtxtErrorExt;
use crate::infer::{InferCtxt, InferCtxtExt, InferOk, TypeFreshener};
use crate::infer::{InferCtxt, InferOk, TypeFreshener};
use crate::solve::InferCtxtSelectExt as _;
use crate::traits::normalize::{normalize_with_depth, normalize_with_depth_to};
use crate::traits::project::{ProjectAndUnifyResult, ProjectionCacheKeyExt};
@ -2579,16 +2580,32 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
// Check that a_ty's supertrait (upcast_principal) is compatible
// with the target (b_ty).
ty::ExistentialPredicate::Trait(target_principal) => {
let hr_source_principal = upcast_principal.map_bound(|trait_ref| {
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
});
let hr_target_principal = bound.rebind(target_principal);
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(
DefineOpaqueTypes::Yes,
bound.rebind(target_principal),
upcast_principal.map_bound(|trait_ref| {
ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
}),
)
.enter_forall(hr_target_principal, |target_principal| {
let source_principal =
self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
hr_source_principal,
);
self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_principal,
hr_source_principal,
),
target_principal,
source_principal,
)
})
.map_err(|_| SelectionError::Unimplemented)?
.into_obligations(),
);
@ -2599,19 +2616,41 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
// return ambiguity. Otherwise, if exactly one matches, equate
// it with b_ty's projection.
ty::ExistentialPredicate::Projection(target_projection) => {
let target_projection = bound.rebind(target_projection);
let hr_target_projection = bound.rebind(target_projection);
let mut matching_projections =
a_data.projection_bounds().filter(|source_projection| {
a_data.projection_bounds().filter(|&hr_source_projection| {
// Eager normalization means that we can just use can_eq
// here instead of equating and processing obligations.
source_projection.item_def_id() == target_projection.item_def_id()
&& self.infcx.can_eq(
obligation.param_env,
*source_projection,
target_projection,
)
hr_source_projection.item_def_id() == hr_target_projection.item_def_id()
&& self.infcx.probe(|_| {
self.infcx
.enter_forall(hr_target_projection, |target_projection| {
let source_projection =
self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
hr_source_projection,
);
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_projection,
hr_source_projection,
),
target_projection,
source_projection,
)
})
.is_ok()
})
});
let Some(source_projection) = matching_projections.next() else {
let Some(hr_source_projection) = matching_projections.next() else {
return Err(SelectionError::Unimplemented);
};
if matching_projections.next().is_some() {
@ -2619,8 +2658,25 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
}
nested.extend(
self.infcx
.at(&obligation.cause, obligation.param_env)
.sup(DefineOpaqueTypes::Yes, target_projection, source_projection)
.enter_forall(hr_target_projection, |target_projection| {
let source_projection =
self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
hr_source_projection,
);
self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
DefineOpaqueTypes::Yes,
ToTrace::to_trace(
&obligation.cause,
true,
hr_target_projection,
hr_source_projection,
),
target_projection,
source_projection,
)
})
.map_err(|_| SelectionError::Unimplemented)?
.into_obligations(),
);