add caches to multiple type folders

This commit is contained in:
lcnr 2024-09-30 10:18:55 +02:00
parent 15ac698393
commit 13881f5404
8 changed files with 222 additions and 22 deletions

View file

@ -1,3 +1,4 @@
use rustc_type_ir::data_structures::DelayedMap;
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::visit::TypeVisitableExt;
@ -15,11 +16,12 @@ where
I: Interner,
{
delegate: &'a D,
cache: DelayedMap<I::Ty, I::Ty>,
}
impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
pub fn new(delegate: &'a D) -> Self {
EagerResolver { delegate }
EagerResolver { delegate, cache: Default::default() }
}
}
@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid),
_ => {
if t.has_infer() {
t.super_fold_with(self)
if let Some(&ty) = self.cache.get(&t) {
return ty;
}
let res = t.super_fold_with(self);
assert!(self.cache.insert(t, res));
res
} else {
t
}

View file

@ -3,7 +3,7 @@ use std::ops::ControlFlow;
use derive_where::derive_where;
#[cfg(feature = "nightly")]
use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
use rustc_type_ir::data_structures::ensure_sufficient_stack;
use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack};
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::relate::Relate;
@ -579,18 +579,16 @@ where
#[instrument(level = "trace", skip(self))]
pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
goal.predicate =
goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal);
self.nested_goals.normalizes_to_goals.push(goal);
}
#[instrument(level = "debug", skip(self))]
pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
goal.predicate =
goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal);
self.nested_goals.goals.push((source, goal));
}
@ -654,6 +652,7 @@ where
term: I::Term,
universe_of_term: ty::UniverseIndex,
delegate: &'a D,
cache: HashSet<I::Ty>,
}
impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> {
@ -671,6 +670,10 @@ where
{
type Result = ControlFlow<()>;
fn visit_ty(&mut self, t: I::Ty) -> Self::Result {
if self.cache.contains(&t) {
return ControlFlow::Continue(());
}
match t.kind() {
ty::Infer(ty::TyVar(vid)) => {
if let ty::TermKind::Ty(term) = self.term.kind() {
@ -683,17 +686,18 @@ where
}
}
self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())
self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?;
}
ty::Placeholder(p) => self.check_nameable(p.universe()),
ty::Placeholder(p) => self.check_nameable(p.universe())?,
_ => {
if t.has_non_region_infer() || t.has_placeholders() {
t.super_visit_with(self)
} else {
ControlFlow::Continue(())
t.super_visit_with(self)?
}
}
}
assert!(self.cache.insert(t));
ControlFlow::Continue(())
}
fn visit_const(&mut self, c: I::Const) -> Self::Result {
@ -728,6 +732,7 @@ where
delegate: self.delegate,
universe_of_term,
term: goal.predicate.term,
cache: Default::default(),
};
goal.predicate.alias.visit_with(&mut visitor).is_continue()
&& goal.param_env.visit_with(&mut visitor).is_continue()
@ -1015,6 +1020,17 @@ where
{
ecx: &'me mut EvalCtxt<'a, D>,
param_env: I::ParamEnv,
cache: HashMap<I::Ty, I::Ty>,
}
impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I>
where
D: SolverDelegate<Interner = I>,
I: Interner,
{
fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self {
ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() }
}
}
impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I>
@ -1041,7 +1057,16 @@ where
);
infer_ty
}
_ => ty.super_fold_with(self),
_ if ty.has_aliases() => {
if let Some(&entry) = self.cache.get(&ty) {
return entry;
}
let res = ty.super_fold_with(self);
assert!(self.cache.insert(ty, res).is_none());
res
}
_ => ty,
}
}