Handle discriminants in dataflow-const-prop.
This commit is contained in:
parent
cd3649b2a5
commit
9a6c04f5d0
6 changed files with 311 additions and 65 deletions
|
@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
|
|||
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
|
||||
use rustc_span::DUMMY_SP;
|
||||
use rustc_target::abi::Align;
|
||||
use rustc_target::abi::VariantIdx;
|
||||
|
||||
use crate::MirPass;
|
||||
|
||||
|
@ -30,6 +31,7 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
|
|||
|
||||
#[instrument(skip_all level = "debug")]
|
||||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
debug!(def_id = ?body.source.def_id());
|
||||
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
|
||||
debug!("aborted dataflow const prop due too many basic blocks");
|
||||
return;
|
||||
|
@ -63,14 +65,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
|
|||
}
|
||||
}
|
||||
|
||||
struct ConstAnalysis<'tcx> {
|
||||
struct ConstAnalysis<'a, 'tcx> {
|
||||
map: Map,
|
||||
tcx: TyCtxt<'tcx>,
|
||||
local_decls: &'a LocalDecls<'tcx>,
|
||||
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||
param_env: ty::ParamEnv<'tcx>,
|
||||
}
|
||||
|
||||
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
|
||||
impl<'tcx> ConstAnalysis<'_, 'tcx> {
|
||||
fn eval_disciminant(
|
||||
&self,
|
||||
enum_ty: Ty<'tcx>,
|
||||
variant_index: VariantIdx,
|
||||
) -> Option<ScalarTy<'tcx>> {
|
||||
if !enum_ty.is_enum() {
|
||||
return None;
|
||||
}
|
||||
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
|
||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
|
||||
let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
|
||||
Some(ScalarTy(discr_value, discr.ty))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
|
||||
type Value = FlatSet<ScalarTy<'tcx>>;
|
||||
|
||||
const NAME: &'static str = "ConstAnalysis";
|
||||
|
@ -79,6 +98,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
|
|||
&self.map
|
||||
}
|
||||
|
||||
fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
|
||||
match statement.kind {
|
||||
StatementKind::SetDiscriminant { box ref place, variant_index } => {
|
||||
state.flood_discr(place.as_ref(), &self.map);
|
||||
if self.map.find_discr(place.as_ref()).is_some() {
|
||||
let enum_ty = place.ty(self.local_decls, self.tcx).ty;
|
||||
if let Some(discr) = self.eval_disciminant(enum_ty, variant_index) {
|
||||
state.assign_discr(
|
||||
place.as_ref(),
|
||||
ValueOrPlace::Value(FlatSet::Elem(discr)),
|
||||
&self.map,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => self.super_statement(statement, state),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_assign(
|
||||
&self,
|
||||
target: Place<'tcx>,
|
||||
|
@ -87,17 +125,22 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
|
|||
) {
|
||||
match rvalue {
|
||||
Rvalue::Aggregate(kind, operands) => {
|
||||
let target = self.map().find(target.as_ref());
|
||||
if let Some(target) = target {
|
||||
state.flood_idx_with(target, self.map(), FlatSet::Bottom);
|
||||
let field_based = match **kind {
|
||||
AggregateKind::Tuple | AggregateKind::Closure(..) => true,
|
||||
AggregateKind::Adt(def_id, ..) => {
|
||||
matches!(self.tcx.def_kind(def_id), DefKind::Struct)
|
||||
state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
|
||||
if let Some(target_idx) = self.map().find(target.as_ref()) {
|
||||
let (variant_target, variant_index) = match **kind {
|
||||
AggregateKind::Tuple | AggregateKind::Closure(..) => {
|
||||
(Some(target_idx), None)
|
||||
}
|
||||
_ => false,
|
||||
AggregateKind::Adt(def_id, variant_index, ..) => {
|
||||
match self.tcx.def_kind(def_id) {
|
||||
DefKind::Struct => (Some(target_idx), None),
|
||||
DefKind::Enum => (Some(target_idx), Some(variant_index)),
|
||||
_ => (None, None),
|
||||
}
|
||||
}
|
||||
_ => (None, None),
|
||||
};
|
||||
if field_based {
|
||||
if let Some(target) = variant_target {
|
||||
for (field_index, operand) in operands.iter().enumerate() {
|
||||
if let Some(field) = self
|
||||
.map()
|
||||
|
@ -108,15 +151,20 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
|
|||
}
|
||||
}
|
||||
}
|
||||
if let Some(variant_index) = variant_index
|
||||
&& let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
|
||||
{
|
||||
let enum_ty = target.ty(self.local_decls, self.tcx).ty;
|
||||
if let Some(discr_val) = self.eval_disciminant(enum_ty, variant_index) {
|
||||
state.assign_idx(discr_idx, ValueOrPlace::Value(FlatSet::Elem(discr_val)), &self.map);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
|
||||
state.flood(target.as_ref(), self.map());
|
||||
|
||||
let target = self.map().find(target.as_ref());
|
||||
if let Some(target) = target {
|
||||
// We should not track any projections other than
|
||||
// what is overwritten below, but just in case...
|
||||
state.flood_idx(target, self.map());
|
||||
}
|
||||
|
||||
let value_target = target
|
||||
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
|
||||
|
@ -195,6 +243,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
|
|||
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
|
||||
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
|
||||
},
|
||||
Rvalue::Discriminant(place) => {
|
||||
ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
|
||||
}
|
||||
_ => self.super_rvalue(rvalue, state),
|
||||
}
|
||||
}
|
||||
|
@ -268,12 +319,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'tcx> ConstAnalysis<'tcx> {
|
||||
pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
|
||||
impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
|
||||
pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
|
||||
let param_env = tcx.param_env(body.source.def_id());
|
||||
Self {
|
||||
map,
|
||||
tcx,
|
||||
local_decls: &body.local_decls,
|
||||
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
|
||||
param_env: param_env,
|
||||
}
|
||||
|
@ -466,6 +518,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
|
|||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
|
||||
match rvalue {
|
||||
Rvalue::Discriminant(place) => {
|
||||
match self.state.get_discr(place.as_ref(), self.visitor.map) {
|
||||
FlatSet::Top => (),
|
||||
FlatSet::Elem(value) => {
|
||||
self.visitor.before_effect.insert((location, *place), value);
|
||||
}
|
||||
FlatSet::Bottom => (),
|
||||
}
|
||||
}
|
||||
_ => self.super_rvalue(rvalue, location),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DummyMachine;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue