1
Fork 0

Rollup merge of #107411 - cjgillot:dataflow-discriminant, r=oli-obk

Handle discriminant in DataflowConstProp

cc ``@jachris``
r? ``@JakobDegen``

This PR attempts to extend the DataflowConstProp pass to handle propagation of discriminants. We handle this by adding 2 new variants to `TrackElem`: `TrackElem::Variant` for enum variants and `TrackElem::Discriminant` for the enum discriminant pseudo-place.

The difficulty is that the enum discriminant and enum variants may alias each another. This is the issue of the `Option<NonZeroUsize>` test, which is the equivalent of https://github.com/rust-lang/unsafe-code-guidelines/issues/84 with a direct write.

To handle that, we generalize the flood process to flood all the potentially aliasing places. In particular:
- any write to `(PLACE as Variant)`, either direct or through a projection, floods `(PLACE as OtherVariant)` for all other variants and `discriminant(PLACE)`;
- `SetDiscriminant(PLACE)` floods `(PLACE as Variant)` for each variant.

This implies that flooding is not hierarchical any more, and that an assignment to a non-tracked place may need to flood a tracked place. This is handled by `for_each_aliasing_place` which generalizes `preorder_invoke`.

As we deaggregate enums by putting `SetDiscriminant` last, this allows to propagate the value of the discriminant.

This refactor will allow to make https://github.com/rust-lang/rust/pull/107009 able to handle discriminants too.
This commit is contained in:
Dylan DPC 2023-02-15 12:24:55 +05:30 committed by GitHub
commit c78e3c735a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 414 additions and 118 deletions

View file

@ -121,7 +121,9 @@ where
// for now. See discussion on [#61069].
//
// [#61069]: https://github.com/rust-lang/rust/pull/61069
self.trans.gen(dropped_place.local);
if !dropped_place.is_indirect() {
self.trans.gen(dropped_place.local);
}
}
TerminatorKind::Abort

View file

@ -1,6 +1,7 @@
#![feature(associated_type_defaults)]
#![feature(box_patterns)]
#![feature(exact_size_is_empty)]
#![feature(let_chains)]
#![feature(min_specialization)]
#![feature(once_cell)]
#![feature(stmt_expr_attributes)]

View file

@ -24,7 +24,7 @@
//! - The bottom state denotes uninitialized memory. Because we are only doing a sound approximation
//! of the actual execution, we can also use this state for places where access would be UB.
//!
//! - The assignment logic in `State::assign_place_idx` assumes that the places are non-overlapping,
//! - The assignment logic in `State::insert_place_idx` assumes that the places are non-overlapping,
//! or identical. Note that this refers to place expressions, not memory locations.
//!
//! - Currently, places that have their reference taken cannot be tracked. Although this would be
@ -35,6 +35,7 @@
use std::fmt::{Debug, Formatter};
use rustc_data_structures::fx::FxHashMap;
use rustc_index::bit_set::BitSet;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
@ -64,10 +65,8 @@ pub trait ValueAnalysis<'tcx> {
StatementKind::Assign(box (place, rvalue)) => {
self.handle_assign(*place, rvalue, state);
}
StatementKind::SetDiscriminant { .. } => {
// Could treat this as writing a constant to a pseudo-place.
// But discriminants are currently not tracked, so we do nothing.
// Related: https://github.com/rust-lang/unsafe-code-guidelines/issues/84
StatementKind::SetDiscriminant { box ref place, .. } => {
state.flood_discr(place.as_ref(), self.map());
}
StatementKind::Intrinsic(box intrinsic) => {
self.handle_intrinsic(intrinsic, state);
@ -446,26 +445,51 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
}
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
if let Some(root) = map.find(place) {
self.flood_idx_with(root, map, value);
}
}
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_with(place, map, V::top())
}
pub fn flood_idx_with(&mut self, place: PlaceIndex, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.preorder_invoke(place, &mut |place| {
map.for_each_aliasing_place(place, None, &mut |place| {
if let Some(vi) = map.places[place].value_index {
values[vi] = value.clone();
}
});
}
pub fn flood_idx(&mut self, place: PlaceIndex, map: &Map) {
self.flood_idx_with(place, map, V::top())
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_with(place, map, V::top())
}
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
let StateData::Reachable(values) = &mut self.0 else { return };
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |place| {
if let Some(vi) = map.places[place].value_index {
values[vi] = value.clone();
}
});
}
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
self.flood_discr_with(place, map, V::top())
}
/// Low-level method that assigns to a place.
/// This does nothing if the place is not tracked.
///
/// The target place must have been flooded before calling this method.
pub fn insert_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
match result {
ValueOrPlace::Value(value) => self.insert_value_idx(target, value, map),
ValueOrPlace::Place(source) => self.insert_place_idx(target, source, map),
}
}
/// Low-level method that assigns a value to a place.
/// This does nothing if the place is not tracked.
///
/// The target place must have been flooded before calling this method.
pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
let StateData::Reachable(values) = &mut self.0 else { return };
if let Some(value_index) = map.places[target].value_index {
values[value_index] = value;
}
}
/// Copies `source` to `target`, including all tracked places beneath.
@ -473,50 +497,41 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
/// If `target` contains a place that is not contained in `source`, it will be overwritten with
/// Top. Also, because this will copy all entries one after another, it may only be used for
/// places that are non-overlapping or identical.
pub fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
///
/// The target place must have been flooded before calling this method.
fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) {
let StateData::Reachable(values) = &mut self.0 else { return };
// If both places are tracked, we copy the value to the target. If the target is tracked,
// but the source is not, we have to invalidate the value in target. If the target is not
// tracked, then we don't have to do anything.
// If both places are tracked, we copy the value to the target.
// If the target is tracked, but the source is not, we do nothing, as invalidation has
// already been performed.
if let Some(target_value) = map.places[target].value_index {
if let Some(source_value) = map.places[source].value_index {
values[target_value] = values[source_value].clone();
} else {
values[target_value] = V::top();
}
}
for target_child in map.children(target) {
// Try to find corresponding child and recurse. Reasoning is similar as above.
let projection = map.places[target_child].proj_elem.unwrap();
if let Some(source_child) = map.projections.get(&(source, projection)) {
self.assign_place_idx(target_child, *source_child, map);
} else {
self.flood_idx(target_child, map);
self.insert_place_idx(target_child, *source_child, map);
}
}
}
/// Helper method to interpret `target = result`.
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
self.flood(target, map);
if let Some(target) = map.find(target) {
self.assign_idx(target, result, map);
} else {
// We don't track this place nor any projections, assignment can be ignored.
self.insert_idx(target, result, map);
}
}
pub fn assign_idx(&mut self, target: PlaceIndex, result: ValueOrPlace<V>, map: &Map) {
match result {
ValueOrPlace::Value(value) => {
// First flood the target place in case we also track any projections (although
// this scenario is currently not well-supported by the API).
self.flood_idx(target, map);
let StateData::Reachable(values) = &mut self.0 else { return };
if let Some(value_index) = map.places[target].value_index {
values[value_index] = value;
}
}
ValueOrPlace::Place(source) => self.assign_place_idx(target, source, map),
/// Helper method for assignments to a discriminant.
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
self.flood_discr(target, map);
if let Some(target) = map.find_discr(target) {
self.insert_idx(target, result, map);
}
}
@ -525,6 +540,14 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::top())
}
/// Retrieve the value stored for a place, or if it is not tracked.
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
match map.find_discr(place) {
Some(place) => self.get_idx(place, map),
None => V::top(),
}
}
/// Retrieve the value stored for a place index, or if it is not tracked.
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
match &self.0 {
@ -581,15 +604,15 @@ impl Map {
/// This is currently the only way to create a [`Map`]. The way in which the tracked places are
/// chosen is an implementation detail and may not be relied upon (other than that their type
/// passes the filter).
#[instrument(skip_all, level = "debug")]
pub fn from_filter<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
filter: impl FnMut(Ty<'tcx>) -> bool,
place_limit: Option<usize>,
) -> Self {
let mut map = Self::new();
let exclude = excluded_locals(body);
map.register_with_filter(tcx, body, filter, &exclude);
map.register_with_filter(tcx, body, filter, exclude, place_limit);
debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len());
map
}
@ -600,20 +623,28 @@ impl Map {
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
mut filter: impl FnMut(Ty<'tcx>) -> bool,
exclude: &IndexVec<Local, bool>,
exclude: BitSet<Local>,
place_limit: Option<usize>,
) {
// We use this vector as stack, pushing and popping projections.
let mut projection = Vec::new();
for (local, decl) in body.local_decls.iter_enumerated() {
if !exclude[local] {
self.register_with_filter_rec(tcx, local, &mut projection, decl.ty, &mut filter);
if !exclude.contains(local) {
self.register_with_filter_rec(
tcx,
local,
&mut projection,
decl.ty,
&mut filter,
place_limit,
);
}
}
}
/// Potentially register the (local, projection) place and its fields, recursively.
///
/// Invariant: The projection must only contain fields.
/// Invariant: The projection must only contain trackable elements.
fn register_with_filter_rec<'tcx>(
&mut self,
tcx: TyCtxt<'tcx>,
@ -621,27 +652,56 @@ impl Map {
projection: &mut Vec<PlaceElem<'tcx>>,
ty: Ty<'tcx>,
filter: &mut impl FnMut(Ty<'tcx>) -> bool,
place_limit: Option<usize>,
) {
// Note: The framework supports only scalars for now.
if filter(ty) && ty.is_scalar() {
// We know that the projection only contains trackable elements.
let place = self.make_place(local, projection).unwrap();
if let Some(place_limit) = place_limit && self.value_count >= place_limit {
return
}
// Allocate a value slot if it doesn't have one.
if self.places[place].value_index.is_none() {
self.places[place].value_index = Some(self.value_count.into());
self.value_count += 1;
// We know that the projection only contains trackable elements.
let place = self.make_place(local, projection).unwrap();
// Allocate a value slot if it doesn't have one, and the user requested one.
if self.places[place].value_index.is_none() && filter(ty) {
self.places[place].value_index = Some(self.value_count.into());
self.value_count += 1;
}
if ty.is_enum() {
let discr_ty = ty.discriminant_ty(tcx);
if filter(discr_ty) {
let discr = *self
.projections
.entry((place, TrackElem::Discriminant))
.or_insert_with(|| {
// Prepend new child to the linked list.
let next = self.places.push(PlaceInfo::new(Some(TrackElem::Discriminant)));
self.places[next].next_sibling = self.places[place].first_child;
self.places[place].first_child = Some(next);
next
});
// Allocate a value slot if it doesn't have one.
if self.places[discr].value_index.is_none() {
self.places[discr].value_index = Some(self.value_count.into());
self.value_count += 1;
}
}
}
// Recurse with all fields of this place.
iter_fields(ty, tcx, |variant, field, ty| {
if variant.is_some() {
// Downcasts are currently not supported.
if let Some(variant) = variant {
projection.push(PlaceElem::Downcast(None, variant));
let _ = self.make_place(local, projection);
projection.push(PlaceElem::Field(field, ty));
self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit);
projection.pop();
projection.pop();
return;
}
projection.push(PlaceElem::Field(field, ty));
self.register_with_filter_rec(tcx, local, projection, ty, filter);
self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit);
projection.pop();
});
}
@ -684,23 +744,105 @@ impl Map {
}
/// Locates the given place, if it exists in the tree.
pub fn find(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
pub fn find_extra(
&self,
place: PlaceRef<'_>,
extra: impl IntoIterator<Item = TrackElem>,
) -> Option<PlaceIndex> {
let mut index = *self.locals.get(place.local)?.as_ref()?;
for &elem in place.projection {
index = self.apply(index, elem.try_into().ok()?)?;
}
for elem in extra {
index = self.apply(index, elem)?;
}
Some(index)
}
/// Locates the given place, if it exists in the tree.
pub fn find(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
self.find_extra(place, [])
}
/// Locates the given place and applies `Discriminant`, if it exists in the tree.
pub fn find_discr(&self, place: PlaceRef<'_>) -> Option<PlaceIndex> {
self.find_extra(place, [TrackElem::Discriminant])
}
/// Iterate over all direct children.
pub fn children(&self, parent: PlaceIndex) -> impl Iterator<Item = PlaceIndex> + '_ {
Children::new(self, parent)
}
/// Invoke a function on the given place and all places that may alias it.
///
/// In particular, when the given place has a variant downcast, we invoke the function on all
/// the other variants.
///
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
/// as such.
pub fn for_each_aliasing_place(
&self,
place: PlaceRef<'_>,
tail_elem: Option<TrackElem>,
f: &mut impl FnMut(PlaceIndex),
) {
if place.is_indirect() {
// We do not track indirect places.
return;
}
let Some(&Some(mut index)) = self.locals.get(place.local) else {
// The local is not tracked at all, so it does not alias anything.
return;
};
let elems = place
.projection
.iter()
.map(|&elem| elem.try_into())
.chain(tail_elem.map(Ok).into_iter());
for elem in elems {
// A field aliases the parent place.
f(index);
let Ok(elem) = elem else { return };
let sub = self.apply(index, elem);
if let TrackElem::Variant(..) | TrackElem::Discriminant = elem {
// Enum variant fields and enum discriminants alias each another.
self.for_each_variant_sibling(index, sub, f);
}
if let Some(sub) = sub {
index = sub
} else {
return;
}
}
self.preorder_invoke(index, f);
}
/// Invoke the given function on all the descendants of the given place, except one branch.
fn for_each_variant_sibling(
&self,
parent: PlaceIndex,
preserved_child: Option<PlaceIndex>,
f: &mut impl FnMut(PlaceIndex),
) {
for sibling in self.children(parent) {
let elem = self.places[sibling].proj_elem;
// Only invalidate variants and discriminant. Fields (for generators) are not
// invalidated by assignment to a variant.
if let Some(TrackElem::Variant(..) | TrackElem::Discriminant) = elem
// Only invalidate the other variants, the current one is fine.
&& Some(sibling) != preserved_child
{
self.preorder_invoke(sibling, f);
}
}
}
/// Invoke a function on the given place and all descendants.
pub fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) {
f(root);
for child in self.children(root) {
self.preorder_invoke(child, f);
@ -759,6 +901,7 @@ impl<'a> Iterator for Children<'a> {
}
/// Used as the result of an operand or r-value.
#[derive(Debug)]
pub enum ValueOrPlace<V> {
Value(V),
Place(PlaceIndex),
@ -776,6 +919,8 @@ impl<V: HasTop> ValueOrPlace<V> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum TrackElem {
Field(Field),
Variant(VariantIdx),
Discriminant,
}
impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
@ -784,6 +929,7 @@ impl<V, T> TryFrom<ProjectionElem<V, T>> for TrackElem {
fn try_from(value: ProjectionElem<V, T>) -> Result<Self, Self::Error> {
match value {
ProjectionElem::Field(field, _) => Ok(TrackElem::Field(field)),
ProjectionElem::Downcast(_, idx) => Ok(TrackElem::Variant(idx)),
_ => Err(()),
}
}
@ -824,26 +970,27 @@ pub fn iter_fields<'tcx>(
}
/// Returns all locals with projections that have their reference or address taken.
pub fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
pub fn excluded_locals(body: &Body<'_>) -> BitSet<Local> {
struct Collector {
result: IndexVec<Local, bool>,
result: BitSet<Local>,
}
impl<'tcx> Visitor<'tcx> for Collector {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if context.is_borrow()
if (context.is_borrow()
|| context.is_address_of()
|| context.is_drop()
|| context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput)
|| context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput))
&& !place.is_indirect()
{
// A pointer to a place could be used to access other places with the same local,
// hence we have to exclude the local completely.
self.result[place.local] = true;
self.result.insert(place.local);
}
}
}
let mut collector = Collector { result: IndexVec::from_elem(false, &body.local_decls) };
let mut collector = Collector { result: BitSet::new_empty(body.local_decls.len()) };
collector.visit_body(body);
collector.result
}
@ -899,6 +1046,12 @@ fn debug_with_context_rec<V: Debug + Eq>(
for child in map.children(place) {
let info_elem = map.places[child].proj_elem.unwrap();
let child_place_str = match info_elem {
TrackElem::Discriminant => {
format!("discriminant({})", place_str)
}
TrackElem::Variant(idx) => {
format!("({} as {:?})", place_str, idx)
}
TrackElem::Field(field) => {
if place_str.starts_with('*') {
format!("({}).{}", place_str, field.index())