Auto merge of #132527 - DianQK:gvn-stmt-iter, r=oli-obk

gvn: Invalid dereferences for all non-local mutations

Fixes #132353.

This PR removes the computation value by traversing SSA locals through `for_each_assignment_mut`.

Because the `for_each_assignment_mut` traversal skips statements which have side effects, such as dereference assignments, the computation may be unsound. Instead of `for_each_assignment_mut`, we compute values by traversing in reverse postorder.

Because we compute and use the symbolic representation of values on the fly, I invalidate all old values when encountering a dereference assignment. The current approach does not prevent the optimization of a clone to a copy.

In the future, we may add an alias model, or dominance information for dereference assignments, or SSA form to help GVN.

r? cjgillot

cc `@jieyouxu` #132356
cc `@RalfJung` #133474
This commit is contained in:
bors 2025-04-03 19:17:33 +00:00
commit 00095b3da4
43 changed files with 567 additions and 574 deletions

View file

@ -3,14 +3,16 @@
//! MIR may contain repeated and/or redundant computations. The objective of this pass is to detect
//! such redundancies and re-use the already-computed result when possible.
//!
//! In a first pass, we compute a symbolic representation of values that are assigned to SSA
//! locals. This symbolic representation is defined by the `Value` enum. Each produced instance of
//! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values.
//!
//! From those assignments, we construct a mapping `VnIndex -> Vec<(Local, Location)>` of available
//! values, the locals in which they are stored, and the assignment location.
//!
//! In a second pass, we traverse all (non SSA) assignments `x = rvalue` and operands. For each
//! We traverse all assignments `x = rvalue` and operands.
//!
//! For each SSA one, we compute a symbolic representation of values that are assigned to SSA
//! locals. This symbolic representation is defined by the `Value` enum. Each produced instance of
//! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values.
//!
//! For each non-SSA
//! one, we compute the `VnIndex` of the rvalue. If this `VnIndex` is associated to a constant, we
//! replace the rvalue/operand by that constant. Otherwise, if there is an SSA local `y`
//! associated to this `VnIndex`, and if its definition location strictly dominates the assignment
@ -91,7 +93,7 @@ use rustc_const_eval::interpret::{
ImmTy, Immediate, InterpCx, MemPlaceMeta, MemoryKind, OpTy, Projectable, Scalar,
intern_const_alloc_for_constprop,
};
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::fx::{FxIndexSet, MutableValues};
use rustc_data_structures::graph::dominators::Dominators;
use rustc_hir::def::DefKind;
use rustc_index::bit_set::DenseBitSet;
@ -107,7 +109,7 @@ use rustc_span::def_id::DefId;
use smallvec::SmallVec;
use tracing::{debug, instrument, trace};
use crate::ssa::{AssignedValue, SsaLocals};
use crate::ssa::SsaLocals;
pub(super) struct GVN;
@ -126,31 +128,11 @@ impl<'tcx> crate::MirPass<'tcx> for GVN {
let dominators = body.basic_blocks.dominators().clone();
let mut state = VnState::new(tcx, body, typing_env, &ssa, dominators, &body.local_decls);
ssa.for_each_assignment_mut(
body.basic_blocks.as_mut_preserves_cfg(),
|local, value, location| {
let value = match value {
// We do not know anything of this assigned value.
AssignedValue::Arg | AssignedValue::Terminator => None,
// Try to get some insight.
AssignedValue::Rvalue(rvalue) => {
let value = state.simplify_rvalue(rvalue, location);
// FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark
// `local` as reusable if we have an exact type match.
if state.local_decls[local].ty != rvalue.ty(state.local_decls, tcx) {
return;
}
value
}
};
// `next_opaque` is `Some`, so `new_opaque` must return `Some`.
let value = value.or_else(|| state.new_opaque()).unwrap();
state.assign(local, value);
},
);
// Stop creating opaques during replacement as it is useless.
state.next_opaque = None;
for local in body.args_iter().filter(|&local| ssa.is_ssa(local)) {
let opaque = state.new_opaque();
state.assign(local, opaque);
}
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
for bb in reverse_postorder {
@ -250,14 +232,14 @@ struct VnState<'body, 'tcx> {
locals: IndexVec<Local, Option<VnIndex>>,
/// Locals that are assigned that value.
// This vector does not hold all the values of `VnIndex` that we create.
// It stops at the largest value created in the first phase of collecting assignments.
rev_locals: IndexVec<VnIndex, SmallVec<[Local; 1]>>,
values: FxIndexSet<Value<'tcx>>,
/// Values evaluated as constants if possible.
evaluated: IndexVec<VnIndex, Option<OpTy<'tcx>>>,
/// Counter to generate different values.
/// This is an option to stop creating opaques during replacement.
next_opaque: Option<usize>,
next_opaque: usize,
/// Cache the deref values.
derefs: Vec<VnIndex>,
/// Cache the value of the `unsized_locals` features, to avoid fetching it repeatedly in a loop.
feature_unsized_locals: bool,
ssa: &'body SsaLocals,
@ -289,7 +271,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
rev_locals: IndexVec::with_capacity(num_values),
values: FxIndexSet::with_capacity_and_hasher(num_values, Default::default()),
evaluated: IndexVec::with_capacity(num_values),
next_opaque: Some(1),
next_opaque: 1,
derefs: Vec::new(),
feature_unsized_locals: tcx.features().unsized_locals(),
ssa,
dominators,
@ -310,32 +293,31 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
let evaluated = self.eval_to_const(index);
let _index = self.evaluated.push(evaluated);
debug_assert_eq!(index, _index);
// No need to push to `rev_locals` if we finished listing assignments.
if self.next_opaque.is_some() {
let _index = self.rev_locals.push(SmallVec::new());
debug_assert_eq!(index, _index);
}
let _index = self.rev_locals.push(SmallVec::new());
debug_assert_eq!(index, _index);
}
index
}
fn next_opaque(&mut self) -> usize {
let next_opaque = self.next_opaque;
self.next_opaque += 1;
next_opaque
}
/// Create a new `Value` for which we have no information at all, except that it is distinct
/// from all the others.
#[instrument(level = "trace", skip(self), ret)]
fn new_opaque(&mut self) -> Option<VnIndex> {
let next_opaque = self.next_opaque.as_mut()?;
let value = Value::Opaque(*next_opaque);
*next_opaque += 1;
Some(self.insert(value))
fn new_opaque(&mut self) -> VnIndex {
let value = Value::Opaque(self.next_opaque());
self.insert(value)
}
/// Create a new `Value::Address` distinct from all the others.
#[instrument(level = "trace", skip(self), ret)]
fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> {
let next_opaque = self.next_opaque.as_mut()?;
let value = Value::Address { place, kind, provenance: *next_opaque };
*next_opaque += 1;
Some(self.insert(value))
fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> VnIndex {
let value = Value::Address { place, kind, provenance: self.next_opaque() };
self.insert(value)
}
fn get(&self, index: VnIndex) -> &Value<'tcx> {
@ -345,6 +327,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
/// Record that `local` is assigned `value`. `local` must be SSA.
#[instrument(level = "trace", skip(self))]
fn assign(&mut self, local: Local, value: VnIndex) {
debug_assert!(self.ssa.is_ssa(local));
self.locals[local] = Some(value);
// Only register the value if its type is `Sized`, as we will emit copies of it.
@ -355,21 +338,19 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}
fn insert_constant(&mut self, value: Const<'tcx>) -> Option<VnIndex> {
fn insert_constant(&mut self, value: Const<'tcx>) -> VnIndex {
let disambiguator = if value.is_deterministic() {
// The constant is deterministic, no need to disambiguate.
0
} else {
// Multiple mentions of this constant will yield different values,
// so assign a different `disambiguator` to ensure they do not get the same `VnIndex`.
let next_opaque = self.next_opaque.as_mut()?;
let disambiguator = *next_opaque;
*next_opaque += 1;
let disambiguator = self.next_opaque();
// `disambiguator: 0` means deterministic.
debug_assert_ne!(disambiguator, 0);
disambiguator
};
Some(self.insert(Value::Constant { value, disambiguator }))
self.insert(Value::Constant { value, disambiguator })
}
fn insert_bool(&mut self, flag: bool) -> VnIndex {
@ -390,6 +371,19 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::ZERO, values))
}
fn insert_deref(&mut self, value: VnIndex) -> VnIndex {
let value = self.insert(Value::Projection(value, ProjectionElem::Deref));
self.derefs.push(value);
value
}
fn invalidate_derefs(&mut self) {
for deref in std::mem::take(&mut self.derefs) {
let opaque = self.next_opaque();
*self.values.get_index_mut2(deref.index()).unwrap() = Value::Opaque(opaque);
}
}
#[instrument(level = "trace", skip(self), ret)]
fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
use Value::*;
@ -648,15 +642,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
let proj = match proj {
ProjectionElem::Deref => {
let ty = place.ty(self.local_decls, self.tcx).ty;
// unsound: https://github.com/rust-lang/rust/issues/130853
if self.tcx.sess.opts.unstable_opts.unsound_mir_opts
&& let Some(Mutability::Not) = ty.ref_mutability()
if let Some(Mutability::Not) = ty.ref_mutability()
&& let Some(pointee_ty) = ty.builtin_deref(true)
&& pointee_ty.is_freeze(self.tcx, self.typing_env())
{
// An immutable borrow `_x` always points to the same value for the
// lifetime of the borrow, so we can merge all instances of `*_x`.
ProjectionElem::Deref
return Some(self.insert_deref(value));
} else {
return None;
}
@ -830,7 +822,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
location: Location,
) -> Option<VnIndex> {
match *operand {
Operand::Constant(ref constant) => self.insert_constant(constant.const_),
Operand::Constant(ref constant) => Some(self.insert_constant(constant.const_)),
Operand::Copy(ref mut place) | Operand::Move(ref mut place) => {
let value = self.simplify_place_value(place, location)?;
if let Some(const_) = self.try_as_constant(value) {
@ -866,11 +858,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
Rvalue::Aggregate(..) => return self.simplify_aggregate(rvalue, location),
Rvalue::Ref(_, borrow_kind, ref mut place) => {
self.simplify_place_projection(place, location);
return self.new_pointer(*place, AddressKind::Ref(borrow_kind));
return Some(self.new_pointer(*place, AddressKind::Ref(borrow_kind)));
}
Rvalue::RawPtr(mutbl, ref mut place) => {
self.simplify_place_projection(place, location);
return self.new_pointer(*place, AddressKind::Address(mutbl));
return Some(self.new_pointer(*place, AddressKind::Address(mutbl)));
}
Rvalue::WrapUnsafeBinder(ref mut op, ty) => {
let value = self.simplify_operand(op, location)?;
@ -1034,7 +1026,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
if is_zst {
let ty = rvalue.ty(self.local_decls, tcx);
return self.insert_constant(Const::zero_sized(ty));
return Some(self.insert_constant(Const::zero_sized(ty)));
}
}
@ -1063,11 +1055,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
};
let fields: Option<Vec<_>> = field_ops
let mut fields: Vec<_> = field_ops
.iter_mut()
.map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
.map(|op| self.simplify_operand(op, location).unwrap_or_else(|| self.new_opaque()))
.collect();
let mut fields = fields?;
if let AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty } = &mut ty {
let mut was_updated = false;
@ -1107,9 +1098,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}
// unsound: https://github.com/rust-lang/rust/issues/132353
if tcx.sess.opts.unstable_opts.unsound_mir_opts
&& let AggregateTy::Def(_, _) = ty
if let AggregateTy::Def(_, _) = ty
&& let Some(value) =
self.simplify_aggregate_to_copy(rvalue, location, &fields, variant_index)
{
@ -1195,7 +1184,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
) if let ty::Slice(..) = to.builtin_deref(true).unwrap().kind()
&& let ty::Array(_, len) = from.builtin_deref(true).unwrap().kind() =>
{
return self.insert_constant(Const::Ty(self.tcx.types.usize, *len));
return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len)));
}
_ => Value::UnaryOp(op, arg_index),
};
@ -1391,7 +1380,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
if let CastKind::PointerCoercion(ReifyFnPointer | ClosureFnPointer(_), _) = kind {
// Each reification of a generic fn may get a different pointer.
// Do not try to merge them.
return self.new_opaque();
return Some(self.new_opaque());
}
let mut was_ever_updated = false;
@ -1507,7 +1496,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
// Trivial case: we are fetching a statically known length.
let place_ty = place.ty(self.local_decls, self.tcx).ty;
if let ty::Array(_, len) = place_ty.kind() {
return self.insert_constant(Const::Ty(self.tcx.types.usize, *len));
return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len)));
}
let mut inner = self.simplify_place_value(place, location)?;
@ -1529,7 +1518,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
&& let Some(to) = to.builtin_deref(true)
&& let ty::Slice(..) = to.kind()
{
return self.insert_constant(Const::Ty(self.tcx.types.usize, *len));
return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len)));
}
// Fallback: a symbolic `Len`.
@ -1739,42 +1728,71 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, 'tcx> {
self.tcx
}
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, location: Location) {
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
self.simplify_place_projection(place, location);
if context.is_mutating_use() && !place.projection.is_empty() {
// Non-local mutation maybe invalidate deref.
self.invalidate_derefs();
}
self.super_place(place, context, location);
}
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
self.simplify_operand(operand, location);
self.super_operand(operand, location);
}
fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
if let StatementKind::Assign(box (ref mut lhs, ref mut rvalue)) = stmt.kind {
self.simplify_place_projection(lhs, location);
// Do not try to simplify a constant, it's already in canonical shape.
if matches!(rvalue, Rvalue::Use(Operand::Constant(_))) {
return;
}
let value = lhs
.as_local()
.and_then(|local| self.locals[local])
.or_else(|| self.simplify_rvalue(rvalue, location));
let Some(value) = value else { return };
if let Some(const_) = self.try_as_constant(value) {
*rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
} else if let Some(local) = self.try_as_local(value, location)
&& *rvalue != Rvalue::Use(Operand::Move(local.into()))
let value = self.simplify_rvalue(rvalue, location);
let value = if let Some(local) = lhs.as_local()
&& self.ssa.is_ssa(local)
// FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark
// `local` as reusable if we have an exact type match.
&& self.local_decls[local].ty == rvalue.ty(self.local_decls, self.tcx)
{
*rvalue = Rvalue::Use(Operand::Copy(local.into()));
self.reused_locals.insert(local);
let value = value.unwrap_or_else(|| self.new_opaque());
self.assign(local, value);
Some(value)
} else {
value
};
if let Some(value) = value {
if let Some(const_) = self.try_as_constant(value) {
*rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
} else if let Some(local) = self.try_as_local(value, location)
&& *rvalue != Rvalue::Use(Operand::Move(local.into()))
{
*rvalue = Rvalue::Use(Operand::Copy(local.into()));
self.reused_locals.insert(local);
}
}
return;
}
self.super_statement(stmt, location);
}
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
if let Terminator { kind: TerminatorKind::Call { destination, .. }, .. } = terminator {
if let Some(local) = destination.as_local()
&& self.ssa.is_ssa(local)
{
let opaque = self.new_opaque();
self.assign(local, opaque);
}
}
// Function calls and ASM may invalidate (nested) derefs. We must handle them carefully.
// Currently, only preserving derefs for trivial terminators like SwitchInt and Goto.
let safe_to_preserve_derefs = matches!(
terminator.kind,
TerminatorKind::SwitchInt { .. } | TerminatorKind::Goto { .. }
);
if !safe_to_preserve_derefs {
self.invalidate_derefs();
}
self.super_terminator(terminator, location);
}
}
struct StorageRemover<'tcx> {

View file

@ -32,12 +32,6 @@ pub(super) struct SsaLocals {
borrowed_locals: DenseBitSet<Local>,
}
pub(super) enum AssignedValue<'a, 'tcx> {
Arg,
Rvalue(&'a mut Rvalue<'tcx>),
Terminator,
}
impl SsaLocals {
pub(super) fn new<'tcx>(
tcx: TyCtxt<'tcx>,
@ -152,38 +146,6 @@ impl SsaLocals {
})
}
pub(super) fn for_each_assignment_mut<'tcx>(
&self,
basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
mut f: impl FnMut(Local, AssignedValue<'_, 'tcx>, Location),
) {
for &local in &self.assignment_order {
match self.assignments[local] {
Set1::One(DefLocation::Argument) => f(
local,
AssignedValue::Arg,
Location { block: START_BLOCK, statement_index: 0 },
),
Set1::One(DefLocation::Assignment(loc)) => {
let bb = &mut basic_blocks[loc.block];
// `loc` must point to a direct assignment to `local`.
let stmt = &mut bb.statements[loc.statement_index];
let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else {
bug!()
};
assert_eq!(target.as_local(), Some(local));
f(local, AssignedValue::Rvalue(rvalue), loc)
}
Set1::One(DefLocation::CallReturn { call, .. }) => {
let bb = &mut basic_blocks[call];
let loc = Location { block: call, statement_index: bb.statements.len() };
f(local, AssignedValue::Terminator, loc)
}
_ => {}
}
}
}
/// Compute the equivalence classes for locals, based on copy statements.
///
/// The returned vector maps each local to the one it copies. In the following case: