1
Fork 0

Make jump threading state sparse.

This commit is contained in:
Camille GILLOT 2024-06-26 21:57:59 +00:00
parent 1834f5a272
commit 76244d4dbc
3 changed files with 86 additions and 38 deletions

View file

@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq {
/// A set that has a "bottom" element, which is less than or equal to any other element. /// A set that has a "bottom" element, which is less than or equal to any other element.
pub trait HasBottom { pub trait HasBottom {
const BOTTOM: Self; const BOTTOM: Self;
fn is_bottom(&self) -> bool;
} }
/// A set that has a "top" element, which is greater than or equal to any other element. /// A set that has a "top" element, which is greater than or equal to any other element.
@ -114,6 +116,10 @@ impl MeetSemiLattice for bool {
impl HasBottom for bool { impl HasBottom for bool {
const BOTTOM: Self = false; const BOTTOM: Self = false;
fn is_bottom(&self) -> bool {
!self
}
} }
impl HasTop for bool { impl HasTop for bool {
@ -267,6 +273,10 @@ impl<T: Clone + Eq> MeetSemiLattice for FlatSet<T> {
impl<T> HasBottom for FlatSet<T> { impl<T> HasBottom for FlatSet<T> {
const BOTTOM: Self = Self::Bottom; const BOTTOM: Self = Self::Bottom;
fn is_bottom(&self) -> bool {
matches!(self, Self::Bottom)
}
} }
impl<T> HasTop for FlatSet<T> { impl<T> HasTop for FlatSet<T> {
@ -291,6 +301,10 @@ impl<T> MaybeReachable<T> {
impl<T> HasBottom for MaybeReachable<T> { impl<T> HasBottom for MaybeReachable<T> {
const BOTTOM: Self = MaybeReachable::Unreachable; const BOTTOM: Self = MaybeReachable::Unreachable;
fn is_bottom(&self) -> bool {
matches!(self, Self::Unreachable)
}
} }
impl<T: HasTop> HasTop for MaybeReachable<T> { impl<T: HasTop> HasTop for MaybeReachable<T> {

View file

@ -36,7 +36,7 @@ use std::collections::VecDeque;
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use std::ops::Range; use std::ops::Range;
use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::fx::{FxHashMap, StdEntry};
use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_index::bit_set::BitSet; use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec; use rustc_index::IndexVec;
@ -342,8 +342,7 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper
fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) {
// The initial state maps all tracked places of argument projections to and the rest to ⊥. // The initial state maps all tracked places of argument projections to and the rest to ⊥.
assert!(matches!(state, State::Unreachable)); assert!(matches!(state, State::Unreachable));
let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count); *state = State::new_reachable();
*state = State::Reachable(values);
for arg in body.args_iter() { for arg in body.args_iter() {
state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map()); state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map());
} }
@ -415,30 +414,54 @@ rustc_index::newtype_index!(
/// See [`State`]. /// See [`State`].
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
struct StateData<V> { pub struct StateData<V> {
map: IndexVec<ValueIndex, V>, bottom: V,
/// This map only contains values that are not `⊥`.
map: FxHashMap<ValueIndex, V>,
} }
impl<V: Clone> StateData<V> { impl<V: HasBottom> StateData<V> {
fn from_elem_n(elem: V, n: usize) -> StateData<V> { fn new() -> StateData<V> {
StateData { map: IndexVec::from_elem_n(elem, n) } StateData { bottom: V::BOTTOM, map: FxHashMap::default() }
}
fn get(&self, idx: ValueIndex) -> &V {
self.map.get(&idx).unwrap_or(&self.bottom)
}
fn insert(&mut self, idx: ValueIndex, elem: V) {
if elem.is_bottom() {
self.map.remove(&idx);
} else {
self.map.insert(idx, elem);
}
} }
} }
impl<V: Clone> Clone for StateData<V> { impl<V: Clone> Clone for StateData<V> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
StateData { map: self.map.clone() } StateData { bottom: self.bottom.clone(), map: self.map.clone() }
} }
fn clone_from(&mut self, source: &Self) { fn clone_from(&mut self, source: &Self) {
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`. self.map.clone_from(&source.map)
self.map.raw.clone_from(&source.map.raw)
} }
} }
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for StateData<V> { impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for StateData<V> {
fn join(&mut self, other: &Self) -> bool { fn join(&mut self, other: &Self) -> bool {
self.map.join(&other.map) let mut changed = false;
#[allow(rustc::potential_query_instability)]
for (i, v) in other.map.iter() {
match self.map.entry(*i) {
StdEntry::Vacant(e) => {
e.insert(v.clone());
changed = true
}
StdEntry::Occupied(e) => changed |= e.into_mut().join(v),
}
}
changed
} }
} }
@ -476,15 +499,19 @@ impl<V: Clone> Clone for State<V> {
} }
} }
impl<V: Clone> State<V> { impl<V: Clone + HasBottom> State<V> {
pub fn new(init: V, map: &Map) -> State<V> { pub fn new_reachable() -> State<V> {
State::Reachable(StateData::from_elem_n(init, map.value_count)) State::Reachable(StateData::new())
} }
pub fn all(&self, f: impl Fn(&V) -> bool) -> bool { pub fn all_bottom(&self) -> bool {
match self { match self {
State::Unreachable => true, State::Unreachable => false,
State::Reachable(ref values) => values.map.iter().all(f), State::Reachable(ref values) =>
{
#[allow(rustc::potential_query_instability)]
values.map.values().all(V::is_bottom)
}
} }
} }
@ -533,9 +560,7 @@ impl<V: Clone> State<V> {
value: V, value: V,
) { ) {
let State::Reachable(values) = self else { return }; let State::Reachable(values) = self else { return };
map.for_each_aliasing_place(place, tail_elem, &mut |vi| { map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone()));
values.map[vi] = value.clone();
});
} }
/// Low-level method that assigns to a place. /// Low-level method that assigns to a place.
@ -556,7 +581,7 @@ impl<V: Clone> State<V> {
pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) { pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) {
let State::Reachable(values) = self else { return }; let State::Reachable(values) = self else { return };
if let Some(value_index) = map.places[target].value_index { if let Some(value_index) = map.places[target].value_index {
values.map[value_index] = value; values.insert(value_index, value)
} }
} }
@ -575,7 +600,7 @@ impl<V: Clone> State<V> {
// already been performed. // already been performed.
if let Some(target_value) = map.places[target].value_index { if let Some(target_value) = map.places[target].value_index {
if let Some(source_value) = map.places[source].value_index { if let Some(source_value) = map.places[source].value_index {
values.map[target_value] = values.map[source_value].clone(); values.insert(target_value, values.get(source_value).clone());
} }
} }
for target_child in map.children(target) { for target_child in map.children(target) {
@ -631,7 +656,7 @@ impl<V: Clone> State<V> {
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> { pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
match self { match self {
State::Reachable(values) => { State::Reachable(values) => {
map.places[place].value_index.map(|v| values.map[v].clone()) map.places[place].value_index.map(|v| values.get(v).clone())
} }
State::Unreachable => None, State::Unreachable => None,
} }
@ -688,7 +713,7 @@ impl<V: Clone> State<V> {
{ {
match self { match self {
State::Reachable(values) => { State::Reachable(values) => {
map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP) map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP)
} }
State::Unreachable => { State::Unreachable => {
// Because this is unreachable, we can return any value we want. // Because this is unreachable, we can return any value we want.
@ -698,7 +723,7 @@ impl<V: Clone> State<V> {
} }
} }
impl<V: JoinSemiLattice + Clone> JoinSemiLattice for State<V> { impl<V: JoinSemiLattice + Clone + HasBottom> JoinSemiLattice for State<V> {
fn join(&mut self, other: &Self) -> bool { fn join(&mut self, other: &Self) -> bool {
match (&mut *self, other) { match (&mut *self, other) {
(_, State::Unreachable) => false, (_, State::Unreachable) => false,
@ -1228,7 +1253,7 @@ where
} }
} }
fn debug_with_context_rec<V: Debug + Eq>( fn debug_with_context_rec<V: Debug + Eq + HasBottom>(
place: PlaceIndex, place: PlaceIndex,
place_str: &str, place_str: &str,
new: &StateData<V>, new: &StateData<V>,
@ -1238,11 +1263,11 @@ fn debug_with_context_rec<V: Debug + Eq>(
) -> std::fmt::Result { ) -> std::fmt::Result {
if let Some(value) = map.places[place].value_index { if let Some(value) = map.places[place].value_index {
match old { match old {
None => writeln!(f, "{}: {:?}", place_str, new.map[value])?, None => writeln!(f, "{}: {:?}", place_str, new.get(value))?,
Some(old) => { Some(old) => {
if new.map[value] != old.map[value] { if new.get(value) != old.get(value) {
writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?; writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?;
writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?; writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?;
} }
} }
} }
@ -1274,7 +1299,7 @@ fn debug_with_context_rec<V: Debug + Eq>(
Ok(()) Ok(())
} }
fn debug_with_context<V: Debug + Eq>( fn debug_with_context<V: Debug + Eq + HasBottom>(
new: &StateData<V>, new: &StateData<V>,
old: Option<&StateData<V>>, old: Option<&StateData<V>>,
map: &Map, map: &Map,

View file

@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*; use rustc_middle::mir::*;
use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt}; use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP; use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants}; use rustc_target::abi::{TagEncoding, Variants};
@ -158,9 +159,17 @@ impl Condition {
} }
} }
#[derive(Copy, Clone, Debug, Default)] #[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]); struct ConditionSet<'a>(&'a [Condition]);
impl HasBottom for ConditionSet<'_> {
const BOTTOM: Self = ConditionSet(&[]);
fn is_bottom(&self) -> bool {
self.0.is_empty()
}
}
impl<'a> ConditionSet<'a> { impl<'a> ConditionSet<'a> {
fn iter(self) -> impl Iterator<Item = Condition> + 'a { fn iter(self) -> impl Iterator<Item = Condition> + 'a {
self.0.iter().copied() self.0.iter().copied()
@ -177,7 +186,7 @@ impl<'a> ConditionSet<'a> {
impl<'tcx, 'a> TOFinder<'tcx, 'a> { impl<'tcx, 'a> TOFinder<'tcx, 'a> {
fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool { fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
state.all(|cs| cs.0.is_empty()) state.all_bottom()
} }
/// Recursion entry point to find threading opportunities. /// Recursion entry point to find threading opportunities.
@ -198,7 +207,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
debug!(?discr); debug!(?discr);
let cost = CostChecker::new(self.tcx, self.param_env, None, self.body); let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
let mut state = State::new(ConditionSet::default(), self.map); let mut state = State::new_reachable();
let conds = if let Some((value, then, else_)) = targets.as_static_if() { let conds = if let Some((value, then, else_)) = targets.as_static_if() {
let value = ScalarInt::try_from_uint(value, discr_layout.size)?; let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
@ -255,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`. // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
// _1 = 6 // _1 = 6
if let Some((lhs, tail)) = self.mutated_statement(stmt) { if let Some((lhs, tail)) = self.mutated_statement(stmt) {
state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default()); state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM);
} }
} }
@ -609,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// We can recurse through this terminator. // We can recurse through this terminator.
let mut state = state(); let mut state = state();
if let Some(place_to_flood) = place_to_flood { if let Some(place_to_flood) = place_to_flood {
state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default()); state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM);
} }
self.find_opportunity(bb, state, cost.clone(), depth + 1); self.find_opportunity(bb, state, cost.clone(), depth + 1);
} }