readd the provisional cache

This commit is contained in:
lcnr 2024-01-08 11:13:50 +01:00
parent eb4d7c7adf
commit 118453c7e1
6 changed files with 166 additions and 62 deletions

View file

@ -19,6 +19,7 @@
#![feature(control_flow_enum)]
#![feature(extract_if)]
#![feature(let_chains)]
#![feature(option_take_if)]
#![feature(if_let_guard)]
#![feature(never_type)]
#![feature(type_alias_impl_trait)]

View file

@ -171,7 +171,8 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
let mut candidates = vec![];
let last_eval_step = match self.evaluation.evaluation.kind {
inspect::CanonicalGoalEvaluationKind::Overflow
| inspect::CanonicalGoalEvaluationKind::CycleInStack => {
| inspect::CanonicalGoalEvaluationKind::CycleInStack
| inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit => {
warn!("unexpected root evaluation: {:?}", self.evaluation);
return vec![];
}

View file

@ -118,6 +118,7 @@ pub(in crate::solve) enum WipGoalEvaluationKind<'tcx> {
pub(in crate::solve) enum WipCanonicalGoalEvaluationKind<'tcx> {
Overflow,
CycleInStack,
ProvisionalCacheHit,
Interned { revisions: &'tcx [inspect::GoalEvaluationStep<'tcx>] },
}
@ -126,6 +127,7 @@ impl std::fmt::Debug for WipCanonicalGoalEvaluationKind<'_> {
match self {
Self::Overflow => write!(f, "Overflow"),
Self::CycleInStack => write!(f, "CycleInStack"),
Self::ProvisionalCacheHit => write!(f, "ProvisionalCacheHit"),
Self::Interned { revisions: _ } => f.debug_struct("Interned").finish_non_exhaustive(),
}
}
@ -151,6 +153,9 @@ impl<'tcx> WipCanonicalGoalEvaluation<'tcx> {
WipCanonicalGoalEvaluationKind::CycleInStack => {
inspect::CanonicalGoalEvaluationKind::CycleInStack
}
WipCanonicalGoalEvaluationKind::ProvisionalCacheHit => {
inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit
}
WipCanonicalGoalEvaluationKind::Interned { revisions } => {
inspect::CanonicalGoalEvaluationKind::Evaluation { revisions }
}

View file

@ -11,7 +11,6 @@ use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, Qu
use rustc_middle::ty;
use rustc_middle::ty::TyCtxt;
use rustc_session::Limit;
use std::collections::hash_map::Entry;
use std::mem;
rustc_index::newtype_index! {
@ -30,7 +29,7 @@ struct StackEntry<'tcx> {
///
/// If so, it must not be moved to the global cache. See
/// [SearchGraph::cycle_participants] for more details.
non_root_cycle_participant: bool,
non_root_cycle_participant: Option<StackDepth>,
encountered_overflow: bool,
has_been_used: bool,
@ -39,6 +38,34 @@ struct StackEntry<'tcx> {
provisional_result: Option<QueryResult<'tcx>>,
}
struct DetachedEntry<'tcx> {
/// The head of the smallest non-trivial cycle involving this entry.
///
/// Given the following rules, when proving `A` the head for
/// the provisional entry of `C` would be `B`.
///
/// A :- B
/// B :- C
/// C :- A + B + C
head: StackDepth,
result: QueryResult<'tcx>,
}
#[derive(Default)]
struct ProvisionalCacheEntry<'tcx> {
stack_depth: Option<StackDepth>,
with_inductive_stack: Option<DetachedEntry<'tcx>>,
with_coinductive_stack: Option<DetachedEntry<'tcx>>,
}
impl<'tcx> ProvisionalCacheEntry<'tcx> {
fn is_empty(&self) -> bool {
self.stack_depth.is_none()
&& self.with_inductive_stack.is_none()
&& self.with_coinductive_stack.is_none()
}
}
pub(super) struct SearchGraph<'tcx> {
mode: SolverMode,
local_overflow_limit: usize,
@ -46,7 +73,7 @@ pub(super) struct SearchGraph<'tcx> {
///
/// An element is *deeper* in the stack if its index is *lower*.
stack: IndexVec<StackDepth, StackEntry<'tcx>>,
stack_entries: FxHashMap<CanonicalInput<'tcx>, StackDepth>,
provisional_cache: FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
/// We put only the root goal of a coinductive cycle into the global cache.
///
/// If we were to use that result when later trying to prove another cycle
@ -63,7 +90,7 @@ impl<'tcx> SearchGraph<'tcx> {
mode,
local_overflow_limit: tcx.recursion_limit().0.checked_ilog2().unwrap_or(0) as usize,
stack: Default::default(),
stack_entries: Default::default(),
provisional_cache: Default::default(),
cycle_participants: Default::default(),
}
}
@ -93,7 +120,6 @@ impl<'tcx> SearchGraph<'tcx> {
/// would cause us to not track overflow and recursion depth correctly.
fn pop_stack(&mut self) -> StackEntry<'tcx> {
let elem = self.stack.pop().unwrap();
assert!(self.stack_entries.remove(&elem.input).is_some());
if let Some(last) = self.stack.raw.last_mut() {
last.reached_depth = last.reached_depth.max(elem.reached_depth);
last.encountered_overflow |= elem.encountered_overflow;
@ -114,7 +140,7 @@ impl<'tcx> SearchGraph<'tcx> {
pub(super) fn is_empty(&self) -> bool {
if self.stack.is_empty() {
debug_assert!(self.stack_entries.is_empty());
debug_assert!(self.provisional_cache.is_empty());
debug_assert!(self.cycle_participants.is_empty());
true
} else {
@ -156,6 +182,40 @@ impl<'tcx> SearchGraph<'tcx> {
}
}
fn stack_coinductive_from(
tcx: TyCtxt<'tcx>,
stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
head: StackDepth,
) -> bool {
stack.raw[head.index()..]
.iter()
.all(|entry| entry.input.value.goal.predicate.is_coinductive(tcx))
}
fn tag_cycle_participants(
stack: &mut IndexVec<StackDepth, StackEntry<'tcx>>,
cycle_participants: &mut FxHashSet<CanonicalInput<'tcx>>,
head: StackDepth,
) {
stack[head].has_been_used = true;
for entry in &mut stack.raw[head.index() + 1..] {
entry.non_root_cycle_participant = entry.non_root_cycle_participant.max(Some(head));
cycle_participants.insert(entry.input);
}
}
fn clear_dependent_provisional_results(
provisional_cache: &mut FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
head: StackDepth,
) {
#[allow(rustc::potential_query_instability)]
provisional_cache.retain(|_, entry| {
entry.with_coinductive_stack.take_if(|p| p.head == head);
entry.with_inductive_stack.take_if(|p| p.head == head);
!entry.is_empty()
});
}
/// Probably the most involved method of the whole solver.
///
/// Given some goal which is proven via the `prove_goal` closure, this
@ -210,23 +270,36 @@ impl<'tcx> SearchGraph<'tcx> {
return result;
}
// Check whether we're in a cycle.
match self.stack_entries.entry(input) {
// No entry, we push this goal on the stack and try to prove it.
Entry::Vacant(v) => {
let depth = self.stack.next_index();
let entry = StackEntry {
input,
available_depth,
reached_depth: depth,
non_root_cycle_participant: false,
encountered_overflow: false,
has_been_used: false,
provisional_result: None,
};
assert_eq!(self.stack.push(entry), depth);
v.insert(depth);
}
// Check whether the goal is in the provisional cache.
let cache_entry = self.provisional_cache.entry(input).or_default();
if let Some(with_coinductive_stack) = &mut cache_entry.with_coinductive_stack
&& Self::stack_coinductive_from(tcx, &self.stack, with_coinductive_stack.head)
{
// We have a nested goal which is already in the provisional cache, use
// its result.
inspect
.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
Self::tag_cycle_participants(
&mut self.stack,
&mut self.cycle_participants,
with_coinductive_stack.head,
);
return with_coinductive_stack.result;
} else if let Some(with_inductive_stack) = &mut cache_entry.with_inductive_stack
&& !Self::stack_coinductive_from(tcx, &self.stack, with_inductive_stack.head)
{
// We have a nested goal which is already in the provisional cache, use
// its result.
inspect
.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
Self::tag_cycle_participants(
&mut self.stack,
&mut self.cycle_participants,
with_inductive_stack.head,
);
return with_inductive_stack.result;
} else if let Some(stack_depth) = cache_entry.stack_depth {
debug!("encountered cycle with depth {stack_depth:?}");
// We have a nested goal which relies on a goal `root` deeper in the stack.
//
// We first store that we may have to reprove `root` in case the provisional
@ -236,40 +309,37 @@ impl<'tcx> SearchGraph<'tcx> {
//
// Finally we can return either the provisional response for that goal if we have a
// coinductive cycle or an ambiguous result if the cycle is inductive.
Entry::Occupied(entry) => {
inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
let stack_depth = *entry.get();
debug!("encountered cycle with depth {stack_depth:?}");
// We start by tagging all non-root cycle participants.
let participants = self.stack.raw.iter_mut().skip(stack_depth.as_usize() + 1);
for entry in participants {
entry.non_root_cycle_participant = true;
self.cycle_participants.insert(entry.input);
}
// If we're in a cycle, we have to retry proving the cycle head
// until we reach a fixpoint. It is not enough to simply retry the
// `root` goal of this cycle.
//
// See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
// for an example.
self.stack[stack_depth].has_been_used = true;
return if let Some(result) = self.stack[stack_depth].provisional_result {
result
inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
Self::tag_cycle_participants(
&mut self.stack,
&mut self.cycle_participants,
stack_depth,
);
return if let Some(result) = self.stack[stack_depth].provisional_result {
result
} else {
// If we don't have a provisional result yet we're in the first iteration,
// so we start with no constraints.
if Self::stack_coinductive_from(tcx, &self.stack, stack_depth) {
Self::response_no_constraints(tcx, input, Certainty::Yes)
} else {
// If we don't have a provisional result yet we're in the first iteration,
// so we start with no constraints.
let is_inductive = self.stack.raw[stack_depth.index()..]
.iter()
.any(|entry| !entry.input.value.goal.predicate.is_coinductive(tcx));
if is_inductive {
Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
} else {
Self::response_no_constraints(tcx, input, Certainty::Yes)
}
};
}
Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
}
};
} else {
// No entry, we push this goal on the stack and try to prove it.
let depth = self.stack.next_index();
let entry = StackEntry {
input,
available_depth,
reached_depth: depth,
non_root_cycle_participant: None,
encountered_overflow: false,
has_been_used: false,
provisional_result: None,
};
assert_eq!(self.stack.push(entry), depth);
cache_entry.stack_depth = Some(depth);
}
// This is for global caching, so we properly track query dependencies.
@ -285,11 +355,22 @@ impl<'tcx> SearchGraph<'tcx> {
for _ in 0..self.local_overflow_limit() {
let result = prove_goal(self, inspect);
// Check whether the current goal is the root of a cycle and whether
// we have to rerun because its provisional result differed from the
// final result.
// Check whether the current goal is the root of a cycle.
// If so, we have to retry proving the cycle head
// until its result reaches a fixpoint. We need to do so for
// all cycle heads, not only for the root.
//
// See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
// for an example.
let stack_entry = self.pop_stack();
debug_assert_eq!(stack_entry.input, input);
if stack_entry.has_been_used {
Self::clear_dependent_provisional_results(
&mut self.provisional_cache,
self.stack.next_index(),
);
}
if stack_entry.has_been_used
&& stack_entry.provisional_result.map_or(true, |r| r != result)
{
@ -299,7 +380,7 @@ impl<'tcx> SearchGraph<'tcx> {
provisional_result: Some(result),
..stack_entry
});
assert_eq!(self.stack_entries.insert(input, depth), None);
debug_assert_eq!(self.provisional_cache[&input].stack_depth, Some(depth));
} else {
return (stack_entry, result);
}
@ -307,6 +388,7 @@ impl<'tcx> SearchGraph<'tcx> {
debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
debug_assert!(!current_entry.has_been_used);
let result = Self::response_no_constraints(tcx, input, Certainty::OVERFLOW);
(current_entry, result)
});
@ -319,7 +401,17 @@ impl<'tcx> SearchGraph<'tcx> {
//
// It is not possible for any nested goal to depend on something deeper on the
// stack, as this would have also updated the depth of the current goal.
if !final_entry.non_root_cycle_participant {
if let Some(head) = final_entry.non_root_cycle_participant {
let coinductive_stack = Self::stack_coinductive_from(tcx, &self.stack, head);
let entry = self.provisional_cache.get_mut(&input).unwrap();
entry.stack_depth = None;
if coinductive_stack {
entry.with_coinductive_stack = Some(DetachedEntry { head, result });
} else {
entry.with_inductive_stack = Some(DetachedEntry { head, result });
}
} else {
// When encountering a cycle, both inductive and coinductive, we only
// move the root into the global cache. We also store all other cycle
// participants involved.
@ -328,6 +420,7 @@ impl<'tcx> SearchGraph<'tcx> {
// participant is on the stack. This is necessary to prevent unstable
// results. See the comment of `SearchGraph::cycle_participants` for
// more details.
self.provisional_cache.remove(&input);
let reached_depth = final_entry.reached_depth.as_usize() - self.stack.len();
let cycle_participants = mem::take(&mut self.cycle_participants);
self.global_cache(tcx).insert(