Support non-scalar constants.

This commit is contained in:
Camille GILLOT 2023-05-13 12:30:40 +00:00
parent 68c2f5ba0f
commit 6ad6b4381c
12 changed files with 259 additions and 22 deletions

View file

@ -3,18 +3,19 @@
//! Currently, this pass only propagates scalar values.
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{ConstValue, Scalar};
use rustc_middle::mir::visit::{MutVisitor, NonMutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{
Map, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
};
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
use rustc_span::DUMMY_SP;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{Align, FieldIdx, VariantIdx};
use crate::MirPass;
@ -111,6 +112,12 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
state: &mut State<Self::Value>,
) {
match rvalue {
Rvalue::Use(operand) => {
state.flood(target.as_ref(), self.map());
if let Some(target) = self.map.find(target.as_ref()) {
self.assign_operand(state, target, operand);
}
}
Rvalue::Aggregate(kind, operands) => {
// If we assign `target = Enum::Variant#0(operand)`,
// we must make sure that all `target as Variant#i` are `Top`.
@ -138,8 +145,7 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
variant_target_idx,
TrackElem::Field(FieldIdx::from_usize(field_index)),
) {
let result = self.handle_operand(operand, state);
state.insert_idx(field, result, self.map());
self.assign_operand(state, field, operand);
}
}
}
@ -330,6 +336,86 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
}
}
/// The caller must have flooded `place`.
fn assign_operand(
&self,
state: &mut State<FlatSet<ScalarInt>>,
place: PlaceIndex,
operand: &Operand<'tcx>,
) {
match operand {
Operand::Copy(rhs) | Operand::Move(rhs) => {
if let Some(rhs) = self.map.find(rhs.as_ref()) {
state.insert_place_idx(place, rhs, &self.map)
}
}
Operand::Constant(box constant) => {
if let Ok(constant) = self.ecx.eval_mir_constant(&constant.literal, None, None) {
self.assign_constant(state, place, constant, &[]);
}
}
}
}
/// The caller must have flooded `place`.
///
/// Perform: `place = operand.projection`.
#[instrument(level = "trace", skip(self, state))]
fn assign_constant(
&self,
state: &mut State<FlatSet<ScalarInt>>,
place: PlaceIndex,
mut operand: OpTy<'tcx>,
projection: &[PlaceElem<'tcx>],
) -> Option<!> {
for &(mut proj_elem) in projection {
if let PlaceElem::Index(index) = proj_elem {
if let FlatSet::Elem(index) = state.get(index.into(), &self.map)
&& let Ok(offset) = index.try_to_target_usize(self.tcx)
&& let Some(min_length) = offset.checked_add(1)
{
proj_elem = PlaceElem::ConstantIndex { offset, min_length, from_end: false };
} else {
return None;
}
}
operand = self.ecx.project(&operand, proj_elem).ok()?;
}
self.map.for_each_projection_value(
place,
operand,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let scalar = self.ecx.discriminant_for_variant(op.layout, variant).ok()?;
let discr_ty = op.layout.ty.discriminant_ty(self.tcx);
let layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
Some(ImmTy::from_scalar(scalar, layout).into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout =
self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(scalar)) = *imm
{
state.insert_value_idx(place, FlatSet::Elem(scalar), &self.map);
}
},
);
None
}
fn binary_op(
&self,
state: &mut State<FlatSet<ScalarInt>>,
@ -604,8 +690,16 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
type MemoryKind = !;
const PANIC_ON_ALLOC_FAIL: bool = true;
#[inline(always)]
fn cur_span(_ecx: &InterpCx<'mir, 'tcx, Self>) -> Span {
DUMMY_SP
}
#[inline(always)]
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
unimplemented!()
// We do not check for alignment to avoid having to carry an `Align`
// in `ConstValue::ByRef`.
CheckAlignment::No
}
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {