Auto merge of #100571 - cjgillot:mir-cost-visit, r=compiler-errors
Check projection types before inlining MIR Fixes https://github.com/rust-lang/rust/issues/100550 I'm very unhappy with this solution, having to duplicate MIR validation code, but at least it removes the ICE. r? `@compiler-errors`
This commit is contained in:
commit
4d45b0745a
3 changed files with 266 additions and 108 deletions
|
@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>(
|
||||||
|
|
||||||
// Normalize lifetimes away on both sides, then compare.
|
// Normalize lifetimes away on both sides, then compare.
|
||||||
let normalize = |ty: Ty<'tcx>| {
|
let normalize = |ty: Ty<'tcx>| {
|
||||||
tcx.normalize_erasing_regions(
|
let ty = ty.fold_with(&mut BottomUpFolder {
|
||||||
param_env,
|
tcx,
|
||||||
ty.fold_with(&mut BottomUpFolder {
|
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
|
||||||
tcx,
|
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
|
||||||
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
|
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
|
||||||
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
|
// since one may have an `impl SomeTrait for fn(&32)` and
|
||||||
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
|
// `impl SomeTrait for fn(&'static u32)` at the same time which
|
||||||
// since one may have an `impl SomeTrait for fn(&32)` and
|
// specify distinct values for Assoc. (See also #56105)
|
||||||
// `impl SomeTrait for fn(&'static u32)` at the same time which
|
lt_op: |_| tcx.lifetimes.re_erased,
|
||||||
// specify distinct values for Assoc. (See also #56105)
|
// Leave consts and types unchanged.
|
||||||
lt_op: |_| tcx.lifetimes.re_erased,
|
ct_op: |ct| ct,
|
||||||
// Leave consts and types unchanged.
|
ty_op: |ty| ty,
|
||||||
ct_op: |ct| ct,
|
});
|
||||||
ty_op: |ty| ty,
|
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty)
|
||||||
}),
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
|
tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyC
|
||||||
use rustc_session::config::OptLevel;
|
use rustc_session::config::OptLevel;
|
||||||
use rustc_span::def_id::DefId;
|
use rustc_span::def_id::DefId;
|
||||||
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
|
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
|
||||||
|
use rustc_target::abi::VariantIdx;
|
||||||
use rustc_target::spec::abi::Abi;
|
use rustc_target::spec::abi::Abi;
|
||||||
|
|
||||||
use super::simplify::{remove_dead_blocks, CfgSimplifier};
|
use super::simplify::{remove_dead_blocks, CfgSimplifier};
|
||||||
|
@ -414,118 +415,60 @@ impl<'tcx> Inliner<'tcx> {
|
||||||
debug!(" final inline threshold = {}", threshold);
|
debug!(" final inline threshold = {}", threshold);
|
||||||
|
|
||||||
// FIXME: Give a bonus to functions with only a single caller
|
// FIXME: Give a bonus to functions with only a single caller
|
||||||
let mut first_block = true;
|
let diverges = matches!(
|
||||||
let mut cost = 0;
|
callee_body.basic_blocks()[START_BLOCK].terminator().kind,
|
||||||
|
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
|
||||||
|
);
|
||||||
|
if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
|
||||||
|
return Err("callee diverges unconditionally");
|
||||||
|
}
|
||||||
|
|
||||||
// Traverse the MIR manually so we can account for the effects of
|
let mut checker = CostChecker {
|
||||||
// inlining on the CFG.
|
tcx: self.tcx,
|
||||||
|
param_env: self.param_env,
|
||||||
|
instance: callsite.callee,
|
||||||
|
callee_body,
|
||||||
|
cost: 0,
|
||||||
|
validation: Ok(()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
|
||||||
let mut work_list = vec![START_BLOCK];
|
let mut work_list = vec![START_BLOCK];
|
||||||
let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
|
let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
|
||||||
while let Some(bb) = work_list.pop() {
|
while let Some(bb) = work_list.pop() {
|
||||||
if !visited.insert(bb.index()) {
|
if !visited.insert(bb.index()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let blk = &callee_body.basic_blocks()[bb];
|
let blk = &callee_body.basic_blocks()[bb];
|
||||||
|
checker.visit_basic_block_data(bb, blk);
|
||||||
|
|
||||||
for stmt in &blk.statements {
|
|
||||||
// Don't count StorageLive/StorageDead in the inlining cost.
|
|
||||||
match stmt.kind {
|
|
||||||
StatementKind::StorageLive(_)
|
|
||||||
| StatementKind::StorageDead(_)
|
|
||||||
| StatementKind::Deinit(_)
|
|
||||||
| StatementKind::Nop => {}
|
|
||||||
_ => cost += INSTR_COST,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let term = blk.terminator();
|
let term = blk.terminator();
|
||||||
let mut is_drop = false;
|
if let TerminatorKind::Drop { ref place, target, unwind }
|
||||||
match term.kind {
|
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
|
||||||
TerminatorKind::Drop { ref place, target, unwind }
|
{
|
||||||
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
|
work_list.push(target);
|
||||||
is_drop = true;
|
|
||||||
work_list.push(target);
|
|
||||||
// If the place doesn't actually need dropping, treat it like
|
|
||||||
// a regular goto.
|
|
||||||
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
|
|
||||||
if ty.needs_drop(tcx, self.param_env) {
|
|
||||||
cost += CALL_PENALTY;
|
|
||||||
if let Some(unwind) = unwind {
|
|
||||||
cost += LANDINGPAD_PENALTY;
|
|
||||||
work_list.push(unwind);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cost += INSTR_COST;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
|
// If the place doesn't actually need dropping, treat it like a regular goto.
|
||||||
if first_block =>
|
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
|
||||||
{
|
if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
|
||||||
// If the function always diverges, don't inline
|
work_list.push(unwind);
|
||||||
// unless the cost is zero
|
|
||||||
threshold = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
|
|
||||||
if let ty::FnDef(def_id, _) =
|
|
||||||
*callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
|
|
||||||
{
|
|
||||||
// Don't give intrinsics the extra penalty for calls
|
|
||||||
if tcx.is_intrinsic(def_id) {
|
|
||||||
cost += INSTR_COST;
|
|
||||||
} else {
|
|
||||||
cost += CALL_PENALTY;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cost += CALL_PENALTY;
|
|
||||||
}
|
}
|
||||||
if cleanup.is_some() {
|
} else {
|
||||||
cost += LANDINGPAD_PENALTY;
|
work_list.extend(term.successors())
|
||||||
}
|
|
||||||
}
|
|
||||||
TerminatorKind::Assert { cleanup, .. } => {
|
|
||||||
cost += CALL_PENALTY;
|
|
||||||
|
|
||||||
if cleanup.is_some() {
|
|
||||||
cost += LANDINGPAD_PENALTY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TerminatorKind::Resume => cost += RESUME_PENALTY,
|
|
||||||
TerminatorKind::InlineAsm { cleanup, .. } => {
|
|
||||||
cost += INSTR_COST;
|
|
||||||
|
|
||||||
if cleanup.is_some() {
|
|
||||||
cost += LANDINGPAD_PENALTY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => cost += INSTR_COST,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !is_drop {
|
|
||||||
for succ in term.successors() {
|
|
||||||
work_list.push(succ);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
first_block = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count up the cost of local variables and temps, if we know the size
|
// Count up the cost of local variables and temps, if we know the size
|
||||||
// use that, otherwise we use a moderately-large dummy cost.
|
// use that, otherwise we use a moderately-large dummy cost.
|
||||||
|
|
||||||
let ptr_size = tcx.data_layout.pointer_size.bytes();
|
|
||||||
|
|
||||||
for v in callee_body.vars_and_temps_iter() {
|
for v in callee_body.vars_and_temps_iter() {
|
||||||
let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
|
checker.visit_local_decl(v, &callee_body.local_decls[v]);
|
||||||
// Cost of the var is the size in machine-words, if we know
|
|
||||||
// it.
|
|
||||||
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
|
|
||||||
cost += ((size + ptr_size - 1) / ptr_size) as usize;
|
|
||||||
} else {
|
|
||||||
cost += UNKNOWN_SIZE_COST;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Abort if type validation found anything fishy.
|
||||||
|
checker.validation?;
|
||||||
|
|
||||||
|
let cost = checker.cost;
|
||||||
if let InlineAttr::Always = callee_attrs.inline {
|
if let InlineAttr::Always = callee_attrs.inline {
|
||||||
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
|
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -799,6 +742,193 @@ fn type_size_of<'tcx>(
|
||||||
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
|
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify that the callee body is compatible with the caller.
|
||||||
|
///
|
||||||
|
/// This visitor mostly computes the inlining cost,
|
||||||
|
/// but also needs to verify that types match because of normalization failure.
|
||||||
|
struct CostChecker<'b, 'tcx> {
|
||||||
|
tcx: TyCtxt<'tcx>,
|
||||||
|
param_env: ParamEnv<'tcx>,
|
||||||
|
cost: usize,
|
||||||
|
callee_body: &'b Body<'tcx>,
|
||||||
|
instance: ty::Instance<'tcx>,
|
||||||
|
validation: Result<(), &'static str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
|
||||||
|
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
|
||||||
|
// Don't count StorageLive/StorageDead in the inlining cost.
|
||||||
|
match statement.kind {
|
||||||
|
StatementKind::StorageLive(_)
|
||||||
|
| StatementKind::StorageDead(_)
|
||||||
|
| StatementKind::Deinit(_)
|
||||||
|
| StatementKind::Nop => {}
|
||||||
|
_ => self.cost += INSTR_COST,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.super_statement(statement, location);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
|
||||||
|
let tcx = self.tcx;
|
||||||
|
match terminator.kind {
|
||||||
|
TerminatorKind::Drop { ref place, unwind, .. }
|
||||||
|
| TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
|
||||||
|
// If the place doesn't actually need dropping, treat it like a regular goto.
|
||||||
|
let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
|
||||||
|
if ty.needs_drop(tcx, self.param_env) {
|
||||||
|
self.cost += CALL_PENALTY;
|
||||||
|
if unwind.is_some() {
|
||||||
|
self.cost += LANDINGPAD_PENALTY;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.cost += INSTR_COST;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
|
||||||
|
let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
|
||||||
|
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
|
||||||
|
// Don't give intrinsics the extra penalty for calls
|
||||||
|
INSTR_COST
|
||||||
|
} else {
|
||||||
|
CALL_PENALTY
|
||||||
|
};
|
||||||
|
if cleanup.is_some() {
|
||||||
|
self.cost += LANDINGPAD_PENALTY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TerminatorKind::Assert { cleanup, .. } => {
|
||||||
|
self.cost += CALL_PENALTY;
|
||||||
|
if cleanup.is_some() {
|
||||||
|
self.cost += LANDINGPAD_PENALTY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TerminatorKind::Resume => self.cost += RESUME_PENALTY,
|
||||||
|
TerminatorKind::InlineAsm { cleanup, .. } => {
|
||||||
|
self.cost += INSTR_COST;
|
||||||
|
if cleanup.is_some() {
|
||||||
|
self.cost += LANDINGPAD_PENALTY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => self.cost += INSTR_COST,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.super_terminator(terminator, location);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count up the cost of local variables and temps, if we know the size
|
||||||
|
/// use that, otherwise we use a moderately-large dummy cost.
|
||||||
|
fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
|
||||||
|
let tcx = self.tcx;
|
||||||
|
let ptr_size = tcx.data_layout.pointer_size.bytes();
|
||||||
|
|
||||||
|
let ty = self.instance.subst_mir(tcx, &local_decl.ty);
|
||||||
|
// Cost of the var is the size in machine-words, if we know
|
||||||
|
// it.
|
||||||
|
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
|
||||||
|
self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
|
||||||
|
} else {
|
||||||
|
self.cost += UNKNOWN_SIZE_COST;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.super_local_decl(local, local_decl)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This method duplicates code from MIR validation in an attempt to detect type mismatches due
|
||||||
|
/// to normalization failure.
|
||||||
|
fn visit_projection_elem(
|
||||||
|
&mut self,
|
||||||
|
local: Local,
|
||||||
|
proj_base: &[PlaceElem<'tcx>],
|
||||||
|
elem: PlaceElem<'tcx>,
|
||||||
|
context: PlaceContext,
|
||||||
|
location: Location,
|
||||||
|
) {
|
||||||
|
if let ProjectionElem::Field(f, ty) = elem {
|
||||||
|
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
|
||||||
|
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
|
||||||
|
let check_equal = |this: &mut Self, f_ty| {
|
||||||
|
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
|
||||||
|
trace!(?ty, ?f_ty);
|
||||||
|
this.validation = Err("failed to normalize projection type");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let kind = match parent_ty.ty.kind() {
|
||||||
|
&ty::Opaque(def_id, substs) => {
|
||||||
|
self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
|
||||||
|
}
|
||||||
|
kind => kind,
|
||||||
|
};
|
||||||
|
|
||||||
|
match kind {
|
||||||
|
ty::Tuple(fields) => {
|
||||||
|
let Some(f_ty) = fields.get(f.as_usize()) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
check_equal(self, *f_ty);
|
||||||
|
}
|
||||||
|
ty::Adt(adt_def, substs) => {
|
||||||
|
let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
|
||||||
|
let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
check_equal(self, field.ty(self.tcx, substs));
|
||||||
|
}
|
||||||
|
ty::Closure(_, substs) => {
|
||||||
|
let substs = substs.as_closure();
|
||||||
|
let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
check_equal(self, f_ty);
|
||||||
|
}
|
||||||
|
&ty::Generator(def_id, substs, _) => {
|
||||||
|
let f_ty = if let Some(var) = parent_ty.variant_index {
|
||||||
|
let gen_body = if def_id == self.callee_body.source.def_id() {
|
||||||
|
self.callee_body
|
||||||
|
} else {
|
||||||
|
self.tcx.optimized_mir(def_id)
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(layout) = gen_body.generator_layout() else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(&local) = layout.variant_fields[var].get(f) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(&f_ty) = layout.field_tys.get(local) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
f_ty
|
||||||
|
} else {
|
||||||
|
let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
|
||||||
|
self.validation = Err("malformed MIR");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
f_ty
|
||||||
|
};
|
||||||
|
|
||||||
|
check_equal(self, f_ty);
|
||||||
|
}
|
||||||
|
_ => self.validation = Err("malformed MIR"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.super_projection_elem(local, proj_base, elem, context, location);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Integrator.
|
* Integrator.
|
||||||
*
|
*
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
// This test verifies that we do not ICE due to MIR inlining in case of normalization failure
|
||||||
|
// in a projection.
|
||||||
|
//
|
||||||
|
// compile-flags: --crate-type lib -C opt-level=3
|
||||||
|
// build-pass
|
||||||
|
|
||||||
|
pub trait Trait {
|
||||||
|
type Associated;
|
||||||
|
}
|
||||||
|
impl<T> Trait for T {
|
||||||
|
type Associated = T;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Struct<T>(<T as Trait>::Associated);
|
||||||
|
|
||||||
|
pub fn foo<T>() -> Struct<T>
|
||||||
|
where
|
||||||
|
T: Trait,
|
||||||
|
{
|
||||||
|
bar()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn bar<T>() -> Struct<T> {
|
||||||
|
Struct(baz())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn baz<T>() -> T {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue