Auto merge of #115796 - cjgillot:const-prop-rvalue, r=oli-obk
Generate aggregate constants in DataflowConstProp.
This commit is contained in:
commit
df871fbf05
23 changed files with 797 additions and 170 deletions
|
@ -2,13 +2,13 @@
|
|||
//!
|
||||
//! Currently, this pass only propagates scalar values.
|
||||
|
||||
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
|
||||
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_hir::def::DefKind;
|
||||
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
|
||||
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
|
||||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::layout::TyAndLayout;
|
||||
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
|
||||
use rustc_middle::ty::{self, Ty, TyCtxt};
|
||||
use rustc_mir_dataflow::value_analysis::{
|
||||
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
|
||||
|
@ -16,8 +16,9 @@ use rustc_mir_dataflow::value_analysis::{
|
|||
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
|
||||
use rustc_span::def_id::DefId;
|
||||
use rustc_span::DUMMY_SP;
|
||||
use rustc_target::abi::{FieldIdx, VariantIdx};
|
||||
use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
|
||||
|
||||
use crate::const_prop::throw_machine_stop_str;
|
||||
use crate::MirPass;
|
||||
|
||||
// These constants are somewhat random guesses and have not been optimized.
|
||||
|
@ -553,18 +554,153 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
|
|||
|
||||
fn try_make_constant(
|
||||
&self,
|
||||
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||
place: Place<'tcx>,
|
||||
state: &State<FlatSet<Scalar>>,
|
||||
map: &Map,
|
||||
) -> Option<Const<'tcx>> {
|
||||
let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
|
||||
return None;
|
||||
};
|
||||
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
|
||||
Some(Const::Val(ConstValue::Scalar(value.into()), ty))
|
||||
let layout = ecx.layout_of(ty).ok()?;
|
||||
|
||||
if layout.is_zst() {
|
||||
return Some(Const::zero_sized(ty));
|
||||
}
|
||||
|
||||
if layout.is_unsized() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let place = map.find(place.as_ref())?;
|
||||
if layout.abi.is_scalar()
|
||||
&& let Some(value) = propagatable_scalar(place, state, map)
|
||||
{
|
||||
return Some(Const::Val(ConstValue::Scalar(value), ty));
|
||||
}
|
||||
|
||||
if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
|
||||
let alloc_id = ecx
|
||||
.intern_with_temp_alloc(layout, |ecx, dest| {
|
||||
try_write_constant(ecx, dest, place, ty, state, map)
|
||||
})
|
||||
.ok()?;
|
||||
return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn propagatable_scalar(
|
||||
place: PlaceIndex,
|
||||
state: &State<FlatSet<Scalar>>,
|
||||
map: &Map,
|
||||
) -> Option<Scalar> {
|
||||
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
|
||||
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(ecx, state, map))]
|
||||
fn try_write_constant<'tcx>(
|
||||
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
|
||||
dest: &PlaceTy<'tcx>,
|
||||
place: PlaceIndex,
|
||||
ty: Ty<'tcx>,
|
||||
state: &State<FlatSet<Scalar>>,
|
||||
map: &Map,
|
||||
) -> InterpResult<'tcx> {
|
||||
let layout = ecx.layout_of(ty)?;
|
||||
|
||||
// Fast path for ZSTs.
|
||||
if layout.is_zst() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Fast path for scalars.
|
||||
if layout.abi.is_scalar()
|
||||
&& let Some(value) = propagatable_scalar(place, state, map)
|
||||
{
|
||||
return ecx.write_immediate(Immediate::Scalar(value), dest);
|
||||
}
|
||||
|
||||
match ty.kind() {
|
||||
// ZSTs. Nothing to do.
|
||||
ty::FnDef(..) => {}
|
||||
|
||||
// Those are scalars, must be handled above.
|
||||
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
|
||||
|
||||
ty::Tuple(elem_tys) => {
|
||||
for (i, elem) in elem_tys.iter().enumerate() {
|
||||
let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
|
||||
throw_machine_stop_str!("missing field in tuple")
|
||||
};
|
||||
let field_dest = ecx.project_field(dest, i)?;
|
||||
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
|
||||
}
|
||||
}
|
||||
|
||||
ty::Adt(def, args) => {
|
||||
if def.is_union() {
|
||||
throw_machine_stop_str!("cannot propagate unions")
|
||||
}
|
||||
|
||||
let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
|
||||
let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
|
||||
throw_machine_stop_str!("missing discriminant for enum")
|
||||
};
|
||||
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
|
||||
throw_machine_stop_str!("discriminant with provenance")
|
||||
};
|
||||
let discr_bits = discr.assert_bits(discr.size());
|
||||
let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
|
||||
throw_machine_stop_str!("illegal discriminant for enum")
|
||||
};
|
||||
let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
|
||||
throw_machine_stop_str!("missing variant for enum")
|
||||
};
|
||||
let variant_dest = ecx.project_downcast(dest, variant)?;
|
||||
(variant, def.variant(variant), variant_place, variant_dest)
|
||||
} else {
|
||||
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
|
||||
};
|
||||
|
||||
for (i, field) in variant_def.fields.iter_enumerated() {
|
||||
let ty = field.ty(*ecx.tcx, args);
|
||||
let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
|
||||
throw_machine_stop_str!("missing field in ADT")
|
||||
};
|
||||
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
|
||||
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
|
||||
}
|
||||
ecx.write_discriminant(variant_idx, dest)?;
|
||||
}
|
||||
|
||||
// Unsupported for now.
|
||||
ty::Array(_, _)
|
||||
|
||||
// Do not attempt to support indirection in constants.
|
||||
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
|
||||
|
||||
| ty::Never
|
||||
| ty::Foreign(..)
|
||||
| ty::Alias(..)
|
||||
| ty::Param(_)
|
||||
| ty::Bound(..)
|
||||
| ty::Placeholder(..)
|
||||
| ty::Closure(..)
|
||||
| ty::Coroutine(..)
|
||||
| ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),
|
||||
|
||||
ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<'mir, 'tcx>
|
||||
ResultsVisitor<'mir, 'tcx, Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>>
|
||||
for Collector<'tcx, '_>
|
||||
|
@ -580,8 +716,13 @@ impl<'mir, 'tcx>
|
|||
) {
|
||||
match &statement.kind {
|
||||
StatementKind::Assign(box (_, rvalue)) => {
|
||||
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
|
||||
.visit_rvalue(rvalue, location);
|
||||
OperandCollector {
|
||||
state,
|
||||
visitor: self,
|
||||
ecx: &mut results.analysis.0.ecx,
|
||||
map: &results.analysis.0.map,
|
||||
}
|
||||
.visit_rvalue(rvalue, location);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
@ -599,7 +740,12 @@ impl<'mir, 'tcx>
|
|||
// Don't overwrite the assignment if it already uses a constant (to keep the span).
|
||||
}
|
||||
StatementKind::Assign(box (place, _)) => {
|
||||
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
|
||||
if let Some(value) = self.try_make_constant(
|
||||
&mut results.analysis.0.ecx,
|
||||
place,
|
||||
state,
|
||||
&results.analysis.0.map,
|
||||
) {
|
||||
self.patch.assignments.insert(location, value);
|
||||
}
|
||||
}
|
||||
|
@ -614,8 +760,13 @@ impl<'mir, 'tcx>
|
|||
terminator: &'mir Terminator<'tcx>,
|
||||
location: Location,
|
||||
) {
|
||||
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
|
||||
.visit_terminator(terminator, location);
|
||||
OperandCollector {
|
||||
state,
|
||||
visitor: self,
|
||||
ecx: &mut results.analysis.0.ecx,
|
||||
map: &results.analysis.0.map,
|
||||
}
|
||||
.visit_terminator(terminator, location);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -670,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
|
|||
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
|
||||
state: &'a State<FlatSet<Scalar>>,
|
||||
visitor: &'a mut Collector<'tcx, 'locals>,
|
||||
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||
map: &'map Map,
|
||||
}
|
||||
|
||||
|
@ -682,7 +834,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
|
|||
location: Location,
|
||||
) {
|
||||
if let PlaceElem::Index(local) = elem
|
||||
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
|
||||
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
|
||||
{
|
||||
self.visitor.patch.before_effect.insert((location, local.into()), value);
|
||||
}
|
||||
|
@ -690,7 +842,9 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
|
|||
|
||||
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
|
||||
if let Some(place) = operand.place() {
|
||||
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
|
||||
if let Some(value) =
|
||||
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
|
||||
{
|
||||
self.visitor.patch.before_effect.insert((location, place), value);
|
||||
} else if !place.projection.is_empty() {
|
||||
// Try to propagate into `Index` projections.
|
||||
|
@ -713,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
|
|||
}
|
||||
|
||||
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
|
||||
unimplemented!()
|
||||
false
|
||||
}
|
||||
|
||||
fn before_access_global(
|
||||
|
@ -725,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
|
|||
is_write: bool,
|
||||
) -> InterpResult<'tcx> {
|
||||
if is_write {
|
||||
crate::const_prop::throw_machine_stop_str!("can't write to global");
|
||||
throw_machine_stop_str!("can't write to global");
|
||||
}
|
||||
|
||||
// If the static allocation is mutable, then we can't const prop it as its content
|
||||
// might be different at runtime.
|
||||
if alloc.inner().mutability.is_mut() {
|
||||
crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
|
||||
throw_machine_stop_str!("can't access mutable globals in ConstProp");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -781,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
|
|||
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
|
||||
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
|
||||
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
|
||||
crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
|
||||
throw_machine_stop_str!("can't do pointer arithmetic");
|
||||
}
|
||||
|
||||
fn expose_ptr(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue