Rollup merge of #107411 - cjgillot:dataflow-discriminant, r=oli-obk

Handle discriminant in DataflowConstProp

cc ``@jachris``
r? ``@JakobDegen``

This PR attempts to extend the DataflowConstProp pass to handle propagation of discriminants. We handle this by adding 2 new variants to `TrackElem`: `TrackElem::Variant` for enum variants and `TrackElem::Discriminant` for the enum discriminant pseudo-place.

The difficulty is that the enum discriminant and the enum variants may alias each other. This is the issue exercised by the `Option<NonZeroUsize>` test, which is the equivalent of https://github.com/rust-lang/unsafe-code-guidelines/issues/84 with a direct write.

To handle that, we generalize the flood process to flood all the potentially aliasing places. In particular:
- any write to `(PLACE as Variant)`, either direct or through a projection, floods `(PLACE as OtherVariant)` for all other variants and `discriminant(PLACE)`;
- `SetDiscriminant(PLACE)` floods `(PLACE as Variant)` for each variant.

This implies that flooding is not hierarchical any more, and that an assignment to a non-tracked place may need to flood a tracked place. This is handled by `for_each_aliasing_place` which generalizes `preorder_invoke`.

As we deaggregate enums by putting `SetDiscriminant` last, this allows us to propagate the value of the discriminant.

This refactor will allow https://github.com/rust-lang/rust/pull/107009 to handle discriminants too.
This commit is contained in:
Dylan DPC 2023-02-15 12:24:55 +05:30 committed by GitHub
commit c78e3c735a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 414 additions and 118 deletions

View file

@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
use rustc_target::abi::VariantIdx;
use crate::MirPass;
@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(def_id = ?body.source.def_id());
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
debug!("aborted dataflow const prop due too many basic blocks");
return;
}
// Decide which places to track during the analysis.
let map = Map::from_filter(tcx, body, Ty::is_scalar);
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
// map nodes is strongly correlated to the number of tracked places, this becomes more or
// less `O(n)` if we place a constant limit on the number of tracked places.
if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
debug!("aborted dataflow const prop due to too many tracked places");
return;
}
let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
// Decide which places to track during the analysis.
let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
// Perform the actual dataflow analysis.
let analysis = ConstAnalysis::new(tcx, body, map);
@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
}
}
struct ConstAnalysis<'tcx> {
struct ConstAnalysis<'a, 'tcx> {
map: Map,
tcx: TyCtxt<'tcx>,
local_decls: &'a LocalDecls<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
}
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
impl<'tcx> ConstAnalysis<'_, 'tcx> {
    /// Evaluates the discriminant of variant `variant_index` of `enum_ty` as a typed scalar.
    ///
    /// Returns `None` when:
    /// - `enum_ty` is not an enum;
    /// - `discriminant_for_variant` yields no discriminant for this variant;
    /// - the layout of the discriminant type cannot be computed;
    /// - the discriminant value cannot be represented as a `Scalar` of the layout's size.
    fn eval_discriminant(
        &self,
        enum_ty: Ty<'tcx>,
        variant_index: VariantIdx,
    ) -> Option<ScalarTy<'tcx>> {
        if enum_ty.is_enum() {
            let variant_discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
            // The scalar must be built with the size of the discriminant's own type.
            let layout = self.tcx.layout_of(self.param_env.and(variant_discr.ty)).ok()?;
            let scalar = Scalar::try_from_uint(variant_discr.val, layout.size)?;
            Some(ScalarTy(scalar, variant_discr.ty))
        } else {
            None
        }
    }
}
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
type Value = FlatSet<ScalarTy<'tcx>>;
const NAME: &'static str = "ConstAnalysis";
@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
&self.map
}
/// Applies the effect of `statement` to `state`.
///
/// `SetDiscriminant` is handled here: the discriminant of the place is flooded first, and
/// then, if the map tracks that discriminant and its value can be evaluated for the given
/// variant, the known value is assigned. Every other statement kind is delegated to
/// `super_statement`.
fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
    match statement.kind {
        StatementKind::SetDiscriminant { box ref place, variant_index } => {
            let place_ref = place.as_ref();
            // Always forget what was known before deciding whether a new value is known.
            state.flood_discr(place_ref, &self.map);
            if self.map.find_discr(place_ref).is_none() {
                // The discriminant is not tracked: nothing to record.
                return;
            }
            let enum_ty = place.ty(self.local_decls, self.tcx).ty;
            if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
                let known = ValueOrPlace::Value(FlatSet::Elem(discr));
                state.assign_discr(place_ref, known, &self.map);
            }
        }
        _ => self.super_statement(statement, state),
    }
}
fn handle_assign(
&self,
target: Place<'tcx>,
@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
) {
match rvalue {
Rvalue::Aggregate(kind, operands) => {
let target = self.map().find(target.as_ref());
if let Some(target) = target {
state.flood_idx_with(target, self.map(), FlatSet::Bottom);
let field_based = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => true,
AggregateKind::Adt(def_id, ..) => {
matches!(self.tcx.def_kind(def_id), DefKind::Struct)
state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom);
if let Some(target_idx) = self.map().find(target.as_ref()) {
let (variant_target, variant_index) = match **kind {
AggregateKind::Tuple | AggregateKind::Closure(..) => {
(Some(target_idx), None)
}
_ => false,
AggregateKind::Adt(def_id, variant_index, ..) => {
match self.tcx.def_kind(def_id) {
DefKind::Struct => (Some(target_idx), None),
DefKind::Enum => (Some(target_idx), Some(variant_index)),
_ => (None, None),
}
}
_ => (None, None),
};
if field_based {
if let Some(target) = variant_target {
for (field_index, operand) in operands.iter().enumerate() {
if let Some(field) = self
.map()
.apply(target, TrackElem::Field(Field::from_usize(field_index)))
{
let result = self.handle_operand(operand, state);
state.assign_idx(field, result, self.map());
state.insert_idx(field, result, self.map());
}
}
}
if let Some(variant_index) = variant_index
&& let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
{
let enum_ty = target.ty(self.local_decls, self.tcx).ty;
if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
}
}
}
}
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
// Flood everything now, so we can use `insert_value_idx` directly later.
state.flood(target.as_ref(), self.map());
let target = self.map().find(target.as_ref());
if let Some(target) = target {
// We should not track any projections other than
// what is overwritten below, but just in case...
state.flood_idx(target, self.map());
}
let value_target = target
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
let (val, overflow) = self.binary_op(state, *op, left, right);
if let Some(value_target) = value_target {
state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
// We have flooded `target` earlier.
state.insert_value_idx(value_target, val, self.map());
}
if let Some(overflow_target) = overflow_target {
let overflow = match overflow {
@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
}
FlatSet::Bottom => FlatSet::Bottom,
};
state.assign_idx(
overflow_target,
ValueOrPlace::Value(overflow),
self.map(),
);
// We have flooded `target` earlier.
state.insert_value_idx(overflow_target, overflow, self.map());
}
}
}
@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
},
Rvalue::Discriminant(place) => {
ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
}
_ => self.super_rvalue(rvalue, state),
}
}
@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
}
}
impl<'tcx> ConstAnalysis<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
let param_env = tcx.param_env(body.source.def_id());
Self {
map,
tcx,
local_decls: &body.local_decls,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env: param_env,
}
@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
_ => (),
}
}
/// Visits `rvalue` at `location`.
///
/// For `Rvalue::Discriminant(place)`, the discriminant of `place` is looked up in the
/// current state; when it is a single known element, it is recorded in the visitor's
/// `before_effect` table under `(location, place)`. `Top` and `Bottom` record nothing.
/// All other rvalues are forwarded to `super_rvalue`.
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
    if let Rvalue::Discriminant(place) = rvalue {
        let discr = self.state.get_discr(place.as_ref(), self.visitor.map);
        if let FlatSet::Elem(value) = discr {
            self.visitor.before_effect.insert((location, *place), value);
        }
    } else {
        self.super_rvalue(rvalue, location);
    }
}
}
struct DummyMachine;

View file

@ -1,5 +1,5 @@
use crate::MirPass;
use rustc_index::bit_set::BitSet;
use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
debug!(?replacements);
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
if !all_dead_locals.is_empty() {
for local in excluded.indices() {
excluded[local] |= all_dead_locals.contains(local);
}
excluded.raw.resize(body.local_decls.len(), false);
excluded.union(&all_dead_locals);
excluded = {
let mut growable = GrowableBitSet::from(excluded);
growable.ensure(body.local_decls.len());
growable.into()
};
} else {
break;
}
@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
/// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code.
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() {
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local);
}
}
@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>,
) -> BitSet<Local> {
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
for (local, replacements) in replacements.fragments.iter_enumerated() {
if replacements.is_some() {
all_dead_locals.insert(local);