Pass correct param-env to error_implies

This commit is contained in:
Michael Goulet 2025-04-03 18:53:48 +00:00
parent d5b4c2e4f1
commit 64b58dd13b
5 changed files with 98 additions and 29 deletions

View file

@ -14,6 +14,7 @@ use rustc_hir::def_id::{DefId, LOCAL_CRATE, LocalDefId};
use rustc_hir::intravisit::Visitor;
use rustc_hir::{self as hir, LangItem, Node};
use rustc_infer::infer::{InferOk, TypeTrace};
use rustc_infer::traits::solve::Goal;
use rustc_middle::traits::SignatureMismatchData;
use rustc_middle::traits::select::OverflowError;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
@ -930,7 +931,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
)) = arg.kind
&& let Node::Pat(pat) = self.tcx.hir_node(*hir_id)
&& let Some((preds, guar)) = self.reported_trait_errors.borrow().get(&pat.span)
&& preds.contains(&obligation.predicate)
&& preds.contains(&obligation.as_goal())
{
return Err(*guar);
}
@ -1292,6 +1293,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
fn can_match_trait(
&self,
param_env: ty::ParamEnv<'tcx>,
goal: ty::TraitPredicate<'tcx>,
assumption: ty::PolyTraitPredicate<'tcx>,
) -> bool {
@ -1306,11 +1308,12 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
assumption,
);
self.can_eq(ty::ParamEnv::empty(), goal.trait_ref, trait_assumption.trait_ref)
self.can_eq(param_env, goal.trait_ref, trait_assumption.trait_ref)
}
fn can_match_projection(
&self,
param_env: ty::ParamEnv<'tcx>,
goal: ty::ProjectionPredicate<'tcx>,
assumption: ty::PolyProjectionPredicate<'tcx>,
) -> bool {
@ -1320,7 +1323,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
assumption,
);
let param_env = ty::ParamEnv::empty();
self.can_eq(param_env, goal.projection_term, assumption.projection_term)
&& self.can_eq(param_env, goal.term, assumption.term)
}
@ -1330,24 +1332,32 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
#[instrument(level = "debug", skip(self), ret)]
pub(super) fn error_implies(
&self,
cond: ty::Predicate<'tcx>,
error: ty::Predicate<'tcx>,
cond: Goal<'tcx, ty::Predicate<'tcx>>,
error: Goal<'tcx, ty::Predicate<'tcx>>,
) -> bool {
if cond == error {
return true;
}
if let Some(error) = error.as_trait_clause() {
// FIXME: We could be smarter about this, i.e. if cond's param-env is a
// subset of error's param-env. This only matters when binders will carry
// predicates though, and obviously only matters for error reporting.
if cond.param_env != error.param_env {
return false;
}
let param_env = error.param_env;
if let Some(error) = error.predicate.as_trait_clause() {
self.enter_forall(error, |error| {
elaborate(self.tcx, std::iter::once(cond))
elaborate(self.tcx, std::iter::once(cond.predicate))
.filter_map(|implied| implied.as_trait_clause())
.any(|implied| self.can_match_trait(error, implied))
.any(|implied| self.can_match_trait(param_env, error, implied))
})
} else if let Some(error) = error.as_projection_clause() {
} else if let Some(error) = error.predicate.as_projection_clause() {
self.enter_forall(error, |error| {
elaborate(self.tcx, std::iter::once(cond))
elaborate(self.tcx, std::iter::once(cond.predicate))
.filter_map(|implied| implied.as_projection_clause())
.any(|implied| self.can_match_projection(error, implied))
.any(|implied| self.can_match_projection(param_env, error, implied))
})
} else {
false

View file

@ -12,6 +12,7 @@ use rustc_errors::{Applicability, Diag, E0038, E0276, MultiSpan, struct_span_cod
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::intravisit::Visitor;
use rustc_hir::{self as hir, AmbigArg, LangItem};
use rustc_infer::traits::solve::Goal;
use rustc_infer::traits::{
DynCompatibilityViolation, Obligation, ObligationCause, ObligationCauseCode,
PredicateObligation, SelectionError,
@ -144,7 +145,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
#[derive(Debug)]
struct ErrorDescriptor<'tcx> {
predicate: ty::Predicate<'tcx>,
goal: Goal<'tcx, ty::Predicate<'tcx>>,
index: Option<usize>, // None if this is an old error
}
@ -152,15 +153,8 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
.reported_trait_errors
.borrow()
.iter()
.map(|(&span, predicates)| {
(
span,
predicates
.0
.iter()
.map(|&predicate| ErrorDescriptor { predicate, index: None })
.collect(),
)
.map(|(&span, goals)| {
(span, goals.0.iter().map(|&goal| ErrorDescriptor { goal, index: None }).collect())
})
.collect();
@ -186,10 +180,10 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
span = expn_data.call_site;
}
error_map.entry(span).or_default().push(ErrorDescriptor {
predicate: error.obligation.predicate,
index: Some(index),
});
error_map
.entry(span)
.or_default()
.push(ErrorDescriptor { goal: error.obligation.as_goal(), index: Some(index) });
}
// We do this in 2 passes because we want to display errors in order, though
@ -210,9 +204,9 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
continue;
}
if self.error_implies(error2.predicate, error.predicate)
if self.error_implies(error2.goal, error.goal)
&& !(error2.index >= error.index
&& self.error_implies(error.predicate, error2.predicate))
&& self.error_implies(error.goal, error2.goal))
{
info!("skipping {:?} (implied by {:?})", error, error2);
is_suppressed[index] = true;
@ -243,7 +237,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
.entry(span)
.or_insert_with(|| (vec![], guar))
.0
.push(error.obligation.predicate);
.push(error.obligation.as_goal());
}
}
}