1
Fork 0

Get rid of the redundant elaboration in middle

This commit is contained in:
Michael Goulet 2024-07-06 12:33:03 -04:00
parent 90423a7abb
commit 66eb346770
7 changed files with 40 additions and 100 deletions

View file

@ -7,7 +7,6 @@ pub mod select;
pub mod solve; pub mod solve;
pub mod specialization_graph; pub mod specialization_graph;
mod structural_impls; mod structural_impls;
pub mod util;
use crate::mir::ConstraintCategory; use crate::mir::ConstraintCategory;
use crate::ty::abstract_const::NotConstEvaluatable; use crate::ty::abstract_const::NotConstEvaluatable;

View file

@ -1,62 +0,0 @@
use rustc_data_structures::fx::FxHashSet;
use crate::ty::{Clause, PolyTraitRef, ToPolyTraitRef, TyCtxt, Upcast};
/// Given a [`PolyTraitRef`], get the [`Clause`]s implied by the trait's definition.
///
/// This only exists in `rustc_middle` because the more powerful elaborator depends on
/// `rustc_infer` for elaborating outlives bounds -- this should only be used for pretty
/// printing.
pub fn super_predicates_for_pretty_printing<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: PolyTraitRef<'tcx>,
) -> impl Iterator<Item = Clause<'tcx>> {
let clause = trait_ref.upcast(tcx);
Elaborator { tcx, visited: FxHashSet::from_iter([clause]), stack: vec![clause] }
}
/// Like [`super_predicates_for_pretty_printing`], except it only returns traits and filters out
/// all other [`Clause`]s.
pub fn supertraits_for_pretty_printing<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: PolyTraitRef<'tcx>,
) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
super_predicates_for_pretty_printing(tcx, trait_ref).filter_map(|clause| {
clause.as_trait_clause().map(|trait_clause| trait_clause.to_poly_trait_ref())
})
}
struct Elaborator<'tcx> {
tcx: TyCtxt<'tcx>,
visited: FxHashSet<Clause<'tcx>>,
stack: Vec<Clause<'tcx>>,
}
impl<'tcx> Elaborator<'tcx> {
fn elaborate(&mut self, trait_ref: PolyTraitRef<'tcx>) {
let super_predicates =
self.tcx.explicit_super_predicates_of(trait_ref.def_id()).predicates.iter().filter_map(
|&(pred, _)| {
let clause = pred.instantiate_supertrait(self.tcx, trait_ref);
self.visited.insert(clause).then_some(clause)
},
);
self.stack.extend(super_predicates);
}
}
impl<'tcx> Iterator for Elaborator<'tcx> {
type Item = Clause<'tcx>;
fn next(&mut self) -> Option<Clause<'tcx>> {
if let Some(clause) = self.stack.pop() {
if let Some(trait_clause) = clause.as_trait_clause() {
self.elaborate(trait_clause.to_poly_trait_ref());
}
Some(clause)
} else {
None
}
}
}

View file

@ -37,7 +37,7 @@ use crate::ty::{GenericArg, GenericArgs, GenericArgsRef};
use rustc_ast::{self as ast, attr}; use rustc_ast::{self as ast, attr};
use rustc_data_structures::defer; use rustc_data_structures::defer;
use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::intern::Interned; use rustc_data_structures::intern::Interned;
use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::profiling::SelfProfilerRef;
use rustc_data_structures::sharded::{IntoPointer, ShardedHashMap}; use rustc_data_structures::sharded::{IntoPointer, ShardedHashMap};
@ -532,10 +532,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.trait_def(trait_def_id).implement_via_object self.trait_def(trait_def_id).implement_via_object
} }
fn supertrait_def_ids(self, trait_def_id: DefId) -> impl IntoIterator<Item = DefId> {
self.supertrait_def_ids(trait_def_id)
}
fn delay_bug(self, msg: impl ToString) -> ErrorGuaranteed { fn delay_bug(self, msg: impl ToString) -> ErrorGuaranteed {
self.dcx().span_delayed_bug(DUMMY_SP, msg.to_string()) self.dcx().span_delayed_bug(DUMMY_SP, msg.to_string())
} }
@ -2495,25 +2491,7 @@ impl<'tcx> TyCtxt<'tcx> {
/// to identify which traits may define a given associated type to help avoid cycle errors, /// to identify which traits may define a given associated type to help avoid cycle errors,
/// and to make size estimates for vtable layout computation. /// and to make size estimates for vtable layout computation.
pub fn supertrait_def_ids(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx { pub fn supertrait_def_ids(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
let mut set = FxHashSet::default(); rustc_type_ir::elaborate::supertrait_def_ids(self, trait_def_id)
let mut stack = vec![trait_def_id];
set.insert(trait_def_id);
iter::from_fn(move || -> Option<DefId> {
let trait_did = stack.pop()?;
let generic_predicates = self.explicit_super_predicates_of(trait_did);
for (predicate, _) in generic_predicates.predicates {
if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}
Some(trait_did)
})
} }
/// Given a closure signature, returns an equivalent fn signature. Detuples /// Given a closure signature, returns an equivalent fn signature. Detuples

View file

@ -1,7 +1,6 @@
use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar}; use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
use crate::query::IntoQueryParam; use crate::query::IntoQueryParam;
use crate::query::Providers; use crate::query::Providers;
use crate::traits::util::{super_predicates_for_pretty_printing, supertraits_for_pretty_printing};
use crate::ty::GenericArgKind; use crate::ty::GenericArgKind;
use crate::ty::{ use crate::ty::{
ConstInt, Expr, ParamConst, ScalarInt, Term, TermKind, TypeFoldable, TypeSuperFoldable, ConstInt, Expr, ParamConst, ScalarInt, Term, TermKind, TypeFoldable, TypeSuperFoldable,
@ -23,6 +22,7 @@ use rustc_span::symbol::{kw, Ident, Symbol};
use rustc_span::FileNameDisplayPreference; use rustc_span::FileNameDisplayPreference;
use rustc_target::abi::Size; use rustc_target::abi::Size;
use rustc_target::spec::abi::Abi; use rustc_target::spec::abi::Abi;
use rustc_type_ir::{elaborate, Upcast as _};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::cell::Cell; use std::cell::Cell;
@ -1255,14 +1255,14 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
entry.has_fn_once = true; entry.has_fn_once = true;
return; return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) { } else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) {
let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref) let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait) .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap(); .unwrap();
fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref); fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
return; return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) { } else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) {
let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref) let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait) .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap(); .unwrap();
@ -1343,10 +1343,11 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let bound_principal_with_self = bound_principal let bound_principal_with_self = bound_principal
.with_self_ty(cx.tcx(), cx.tcx().types.trait_object_dummy_self); .with_self_ty(cx.tcx(), cx.tcx().types.trait_object_dummy_self);
let super_projections: Vec<_> = let clause: ty::Clause<'tcx> = bound_principal_with_self.upcast(cx.tcx());
super_predicates_for_pretty_printing(cx.tcx(), bound_principal_with_self) let super_projections: Vec<_> = elaborate::elaborate(cx.tcx(), [clause])
.filter_map(|clause| clause.as_projection_clause()) .filter_only_self()
.collect(); .filter_map(|clause| clause.as_projection_clause())
.collect();
let mut projections: Vec<_> = predicates let mut projections: Vec<_> = predicates
.projection_bounds() .projection_bounds()

View file

@ -6,7 +6,7 @@ use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*; use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem; use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _; use rustc_type_ir::visit::TypeVisitableExt as _;
use rustc_type_ir::{self as ty, Interner, TraitPredicate, Upcast as _}; use rustc_type_ir::{self as ty, elaborate, Interner, TraitPredicate, Upcast as _};
use tracing::{instrument, trace}; use tracing::{instrument, trace};
use crate::delegate::SolverDelegate; use crate::delegate::SolverDelegate;
@ -862,8 +862,7 @@ where
.auto_traits() .auto_traits()
.into_iter() .into_iter()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| { .chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
self.cx() elaborate::supertrait_def_ids(self.cx(), principal_def_id)
.supertrait_def_ids(principal_def_id)
.into_iter() .into_iter()
.filter(|def_id| self.cx().trait_is_auto(*def_id)) .filter(|def_id| self.cx().trait_is_auto(*def_id))
})) }))

View file

@ -229,6 +229,34 @@ impl<I: Interner, O: Elaboratable<I>> Iterator for Elaborator<I, O> {
// Supertrait iterator // Supertrait iterator
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
/// Computes the def-ids of the transitive supertraits of `trait_def_id`. This (intentionally)
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
/// to identify which traits may define a given associated type to help avoid cycle errors,
/// and to make size estimates for vtable layout computation.
pub fn supertrait_def_ids<I: Interner>(
cx: I,
trait_def_id: I::DefId,
) -> impl Iterator<Item = I::DefId> {
let mut set = HashSet::default();
let mut stack = vec![trait_def_id];
set.insert(trait_def_id);
std::iter::from_fn(move || {
let trait_def_id = stack.pop()?;
for (predicate, _) in cx.explicit_super_predicates_of(trait_def_id).iter_identity() {
if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}
Some(trait_def_id)
})
}
pub fn supertraits<I: Interner>( pub fn supertraits<I: Interner>(
tcx: I, tcx: I,
trait_ref: ty::Binder<I, ty::TraitRef<I>>, trait_ref: ty::Binder<I, ty::TraitRef<I>>,

View file

@ -253,9 +253,6 @@ pub trait Interner:
fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool; fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool;
fn supertrait_def_ids(self, trait_def_id: Self::DefId)
-> impl IntoIterator<Item = Self::DefId>;
fn delay_bug(self, msg: impl ToString) -> Self::ErrorGuaranteed; fn delay_bug(self, msg: impl ToString) -> Self::ErrorGuaranteed;
fn is_general_coroutine(self, coroutine_def_id: Self::DefId) -> bool; fn is_general_coroutine(self, coroutine_def_id: Self::DefId) -> bool;