1
Fork 0

Make search graph generic over interner

This commit is contained in:
Michael Goulet 2024-05-18 10:03:53 -04:00
parent d84b903754
commit 91685c0ef4
6 changed files with 75 additions and 54 deletions

View file

@ -233,6 +233,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
fn parent(self, def_id: Self::DefId) -> Self::DefId { fn parent(self, def_id: Self::DefId) -> Self::DefId {
self.parent(def_id) self.parent(def_id)
} }
fn recursion_limit(self) -> usize {
self.recursion_limit().0
}
} }
impl<'tcx> rustc_type_ir::inherent::Abi<TyCtxt<'tcx>> for abi::Abi { impl<'tcx> rustc_type_ir::inherent::Abi<TyCtxt<'tcx>> for abi::Abi {

View file

@ -37,7 +37,11 @@ pub struct Predicate<'tcx>(
pub(super) Interned<'tcx, WithCachedTypeInfo<ty::Binder<'tcx, PredicateKind<'tcx>>>>, pub(super) Interned<'tcx, WithCachedTypeInfo<ty::Binder<'tcx, PredicateKind<'tcx>>>>,
); );
impl<'tcx> rustc_type_ir::inherent::Predicate<TyCtxt<'tcx>> for Predicate<'tcx> {} impl<'tcx> rustc_type_ir::inherent::Predicate<TyCtxt<'tcx>> for Predicate<'tcx> {
fn is_coinductive(self, interner: TyCtxt<'tcx>) -> bool {
self.is_coinductive(interner)
}
}
impl<'tcx> rustc_type_ir::visit::Flags for Predicate<'tcx> { impl<'tcx> rustc_type_ir::visit::Flags for Predicate<'tcx> {
fn flags(&self) -> TypeFlags { fn flags(&self) -> TypeFlags {

View file

@ -85,7 +85,7 @@ pub struct EvalCtxt<'a, 'tcx> {
/// new placeholders to the caller. /// new placeholders to the caller.
pub(super) max_input_universe: ty::UniverseIndex, pub(super) max_input_universe: ty::UniverseIndex,
pub(super) search_graph: &'a mut SearchGraph<'tcx>, pub(super) search_graph: &'a mut SearchGraph<TyCtxt<'tcx>>,
nested_goals: NestedGoals<TyCtxt<'tcx>>, nested_goals: NestedGoals<TyCtxt<'tcx>>,
@ -225,7 +225,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
/// and registering opaques from the canonicalized input. /// and registering opaques from the canonicalized input.
fn enter_canonical<R>( fn enter_canonical<R>(
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
search_graph: &'a mut search_graph::SearchGraph<'tcx>, search_graph: &'a mut search_graph::SearchGraph<TyCtxt<'tcx>>,
canonical_input: CanonicalInput<'tcx>, canonical_input: CanonicalInput<'tcx>,
canonical_goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>, canonical_goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
f: impl FnOnce(&mut EvalCtxt<'_, 'tcx>, Goal<'tcx, ty::Predicate<'tcx>>) -> R, f: impl FnOnce(&mut EvalCtxt<'_, 'tcx>, Goal<'tcx, ty::Predicate<'tcx>>) -> R,
@ -287,7 +287,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
#[instrument(level = "debug", skip(tcx, search_graph, goal_evaluation), ret)] #[instrument(level = "debug", skip(tcx, search_graph, goal_evaluation), ret)]
fn evaluate_canonical_goal( fn evaluate_canonical_goal(
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
search_graph: &'a mut search_graph::SearchGraph<'tcx>, search_graph: &'a mut search_graph::SearchGraph<TyCtxt<'tcx>>,
canonical_input: CanonicalInput<'tcx>, canonical_input: CanonicalInput<'tcx>,
goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>, goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
) -> QueryResult<'tcx> { ) -> QueryResult<'tcx> {

View file

@ -1,18 +1,21 @@
use crate::solve::FIXPOINT_STEP_LIMIT; use std::mem;
use super::inspect; use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use super::inspect::ProofTreeBuilder;
use super::SolverMode;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::fx::FxHashSet;
use rustc_index::Idx; use rustc_index::Idx;
use rustc_index::IndexVec; use rustc_index::IndexVec;
use rustc_middle::dep_graph::dep_kinds; use rustc_middle::dep_graph::dep_kinds;
use rustc_middle::traits::solve::CacheData; use rustc_middle::traits::solve::CacheData;
use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, QueryResult}; use rustc_middle::traits::solve::EvaluationCache;
use rustc_middle::ty::TyCtxt; use rustc_middle::ty::TyCtxt;
use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult};
use rustc_session::Limit; use rustc_session::Limit;
use std::mem; use rustc_type_ir::inherent::*;
use rustc_type_ir::Interner;
use super::inspect;
use super::inspect::ProofTreeBuilder;
use super::SolverMode;
use crate::solve::FIXPOINT_STEP_LIMIT;
rustc_index::newtype_index! { rustc_index::newtype_index! {
#[orderable] #[orderable]
@ -30,9 +33,10 @@ bitflags::bitflags! {
} }
} }
#[derive(Debug)] #[derive(derivative::Derivative)]
struct StackEntry<'tcx> { #[derivative(Debug(bound = ""))]
input: CanonicalInput<'tcx>, struct StackEntry<I: Interner> {
input: CanonicalInput<I>,
available_depth: Limit, available_depth: Limit,
@ -53,11 +57,11 @@ struct StackEntry<'tcx> {
has_been_used: HasBeenUsed, has_been_used: HasBeenUsed,
/// Starts out as `None` and gets set when rerunning this /// Starts out as `None` and gets set when rerunning this
/// goal in case we encounter a cycle. /// goal in case we encounter a cycle.
provisional_result: Option<QueryResult<'tcx>>, provisional_result: Option<QueryResult<I>>,
} }
/// The provisional result for a goal which is not on the stack. /// The provisional result for a goal which is not on the stack.
struct DetachedEntry<'tcx> { struct DetachedEntry<I: Interner> {
/// The head of the smallest non-trivial cycle involving this entry. /// The head of the smallest non-trivial cycle involving this entry.
/// ///
/// Given the following rules, when proving `A` the head for /// Given the following rules, when proving `A` the head for
@ -68,7 +72,7 @@ struct DetachedEntry<'tcx> {
/// C :- A + B + C /// C :- A + B + C
/// ``` /// ```
head: StackDepth, head: StackDepth,
result: QueryResult<'tcx>, result: QueryResult<I>,
} }
/// Stores the stack depth of a currently evaluated goal *and* already /// Stores the stack depth of a currently evaluated goal *and* already
@ -83,14 +87,15 @@ struct DetachedEntry<'tcx> {
/// ///
/// The provisional cache can theoretically result in changes to the observable behavior, /// The provisional cache can theoretically result in changes to the observable behavior,
/// see tests/ui/traits/next-solver/cycles/provisional-cache-impacts-behavior.rs. /// see tests/ui/traits/next-solver/cycles/provisional-cache-impacts-behavior.rs.
#[derive(Default)] #[derive(derivative::Derivative)]
struct ProvisionalCacheEntry<'tcx> { #[derivative(Default(bound = ""))]
struct ProvisionalCacheEntry<I: Interner> {
stack_depth: Option<StackDepth>, stack_depth: Option<StackDepth>,
with_inductive_stack: Option<DetachedEntry<'tcx>>, with_inductive_stack: Option<DetachedEntry<I>>,
with_coinductive_stack: Option<DetachedEntry<'tcx>>, with_coinductive_stack: Option<DetachedEntry<I>>,
} }
impl<'tcx> ProvisionalCacheEntry<'tcx> { impl<I: Interner> ProvisionalCacheEntry<I> {
fn is_empty(&self) -> bool { fn is_empty(&self) -> bool {
self.stack_depth.is_none() self.stack_depth.is_none()
&& self.with_inductive_stack.is_none() && self.with_inductive_stack.is_none()
@ -98,13 +103,13 @@ impl<'tcx> ProvisionalCacheEntry<'tcx> {
} }
} }
pub(super) struct SearchGraph<'tcx> { pub(super) struct SearchGraph<I: Interner> {
mode: SolverMode, mode: SolverMode,
/// The stack of goals currently being computed. /// The stack of goals currently being computed.
/// ///
/// An element is *deeper* in the stack if its index is *lower*. /// An element is *deeper* in the stack if its index is *lower*.
stack: IndexVec<StackDepth, StackEntry<'tcx>>, stack: IndexVec<StackDepth, StackEntry<I>>,
provisional_cache: FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>, provisional_cache: FxHashMap<CanonicalInput<I>, ProvisionalCacheEntry<I>>,
/// We put only the root goal of a coinductive cycle into the global cache. /// We put only the root goal of a coinductive cycle into the global cache.
/// ///
/// If we were to use that result when later trying to prove another cycle /// If we were to use that result when later trying to prove another cycle
@ -112,11 +117,11 @@ pub(super) struct SearchGraph<'tcx> {
/// ///
/// See tests/ui/next-solver/coinduction/incompleteness-unstable-result.rs for /// See tests/ui/next-solver/coinduction/incompleteness-unstable-result.rs for
/// an example of where this is needed. /// an example of where this is needed.
cycle_participants: FxHashSet<CanonicalInput<'tcx>>, cycle_participants: FxHashSet<CanonicalInput<I>>,
} }
impl<'tcx> SearchGraph<'tcx> { impl<I: Interner> SearchGraph<I> {
pub(super) fn new(mode: SolverMode) -> SearchGraph<'tcx> { pub(super) fn new(mode: SolverMode) -> SearchGraph<I> {
Self { Self {
mode, mode,
stack: Default::default(), stack: Default::default(),
@ -144,7 +149,7 @@ impl<'tcx> SearchGraph<'tcx> {
/// ///
/// Directly popping from the stack instead of using this method /// Directly popping from the stack instead of using this method
/// would cause us to not track overflow and recursion depth correctly. /// would cause us to not track overflow and recursion depth correctly.
fn pop_stack(&mut self) -> StackEntry<'tcx> { fn pop_stack(&mut self) -> StackEntry<I> {
let elem = self.stack.pop().unwrap(); let elem = self.stack.pop().unwrap();
if let Some(last) = self.stack.raw.last_mut() { if let Some(last) = self.stack.raw.last_mut() {
last.reached_depth = last.reached_depth.max(elem.reached_depth); last.reached_depth = last.reached_depth.max(elem.reached_depth);
@ -153,17 +158,6 @@ impl<'tcx> SearchGraph<'tcx> {
elem elem
} }
/// The trait solver behavior is different for coherence
/// so we use a separate cache. Alternatively we could use
/// a single cache and share it between coherence and ordinary
/// trait solving.
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
match self.mode {
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
}
}
pub(super) fn is_empty(&self) -> bool { pub(super) fn is_empty(&self) -> bool {
if self.stack.is_empty() { if self.stack.is_empty() {
debug_assert!(self.provisional_cache.is_empty()); debug_assert!(self.provisional_cache.is_empty());
@ -181,8 +175,8 @@ impl<'tcx> SearchGraph<'tcx> {
/// the remaining depth of all nested goals to prevent hangs /// the remaining depth of all nested goals to prevent hangs
/// in case there is exponential blowup. /// in case there is exponential blowup.
fn allowed_depth_for_nested( fn allowed_depth_for_nested(
tcx: TyCtxt<'tcx>, tcx: I,
stack: &IndexVec<StackDepth, StackEntry<'tcx>>, stack: &IndexVec<StackDepth, StackEntry<I>>,
) -> Option<Limit> { ) -> Option<Limit> {
if let Some(last) = stack.raw.last() { if let Some(last) = stack.raw.last() {
if last.available_depth.0 == 0 { if last.available_depth.0 == 0 {
@ -195,13 +189,13 @@ impl<'tcx> SearchGraph<'tcx> {
Limit(last.available_depth.0 - 1) Limit(last.available_depth.0 - 1)
}) })
} else { } else {
Some(tcx.recursion_limit()) Some(Limit(tcx.recursion_limit()))
} }
} }
fn stack_coinductive_from( fn stack_coinductive_from(
tcx: TyCtxt<'tcx>, tcx: I,
stack: &IndexVec<StackDepth, StackEntry<'tcx>>, stack: &IndexVec<StackDepth, StackEntry<I>>,
head: StackDepth, head: StackDepth,
) -> bool { ) -> bool {
stack.raw[head.index()..] stack.raw[head.index()..]
@ -220,8 +214,8 @@ impl<'tcx> SearchGraph<'tcx> {
// we reach a fixpoint and all other cycle participants to make sure that // we reach a fixpoint and all other cycle participants to make sure that
// their result does not get moved to the global cache. // their result does not get moved to the global cache.
fn tag_cycle_participants( fn tag_cycle_participants(
stack: &mut IndexVec<StackDepth, StackEntry<'tcx>>, stack: &mut IndexVec<StackDepth, StackEntry<I>>,
cycle_participants: &mut FxHashSet<CanonicalInput<'tcx>>, cycle_participants: &mut FxHashSet<CanonicalInput<I>>,
usage_kind: HasBeenUsed, usage_kind: HasBeenUsed,
head: StackDepth, head: StackDepth,
) { ) {
@ -234,7 +228,7 @@ impl<'tcx> SearchGraph<'tcx> {
} }
fn clear_dependent_provisional_results( fn clear_dependent_provisional_results(
provisional_cache: &mut FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>, provisional_cache: &mut FxHashMap<CanonicalInput<I>, ProvisionalCacheEntry<I>>,
head: StackDepth, head: StackDepth,
) { ) {
#[allow(rustc::potential_query_instability)] #[allow(rustc::potential_query_instability)]
@ -244,6 +238,19 @@ impl<'tcx> SearchGraph<'tcx> {
!entry.is_empty() !entry.is_empty()
}); });
} }
}
impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
/// The trait solver behavior is different for coherence
/// so we use a separate cache. Alternatively we could use
/// a single cache and share it between coherence and ordinary
/// trait solving.
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
match self.mode {
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
}
}
/// Probably the most involved method of the whole solver. /// Probably the most involved method of the whole solver.
/// ///
@ -252,10 +259,13 @@ impl<'tcx> SearchGraph<'tcx> {
pub(super) fn with_new_goal( pub(super) fn with_new_goal(
&mut self, &mut self,
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
input: CanonicalInput<'tcx>, input: CanonicalInput<TyCtxt<'tcx>>,
inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>, inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<TyCtxt<'tcx>>) -> QueryResult<'tcx>, mut prove_goal: impl FnMut(
) -> QueryResult<'tcx> { &mut Self,
&mut ProofTreeBuilder<TyCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>>,
) -> QueryResult<TyCtxt<'tcx>> {
// Check for overflow. // Check for overflow.
let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else { let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
if let Some(last) = self.stack.raw.last_mut() { if let Some(last) = self.stack.raw.last_mut() {
@ -489,9 +499,9 @@ impl<'tcx> SearchGraph<'tcx> {
fn response_no_constraints( fn response_no_constraints(
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
goal: CanonicalInput<'tcx>, goal: CanonicalInput<TyCtxt<'tcx>>,
certainty: Certainty, certainty: Certainty,
) -> QueryResult<'tcx> { ) -> QueryResult<TyCtxt<'tcx>> {
Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty)) Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
} }
} }

View file

@ -96,6 +96,7 @@ pub trait GenericArgs<I: Interner<GenericArgs = Self>>:
pub trait Predicate<I: Interner<Predicate = Self>>: pub trait Predicate<I: Interner<Predicate = Self>>:
Copy + Debug + Hash + Eq + TypeSuperVisitable<I> + TypeSuperFoldable<I> + Flags Copy + Debug + Hash + Eq + TypeSuperVisitable<I> + TypeSuperFoldable<I> + Flags
{ {
fn is_coinductive(self, interner: I) -> bool;
} }
/// Common capabilities of placeholder kinds /// Common capabilities of placeholder kinds

View file

@ -124,6 +124,8 @@ pub trait Interner:
) -> Self::GenericArgs; ) -> Self::GenericArgs;
fn parent(self, def_id: Self::DefId) -> Self::DefId; fn parent(self, def_id: Self::DefId) -> Self::DefId;
fn recursion_limit(self) -> usize;
} }
/// Imagine you have a function `F: FnOnce(&[T]) -> R`, plus an iterator `iter` /// Imagine you have a function `F: FnOnce(&[T]) -> R`, plus an iterator `iter`