1
Fork 0

Only exclude locals if the place is not indirect.

This commit is contained in:
Camille GILLOT 2023-01-21 21:53:26 +00:00
parent 0d59b8c997
commit cd3649b2a5
3 changed files with 24 additions and 18 deletions

View file

@ -121,7 +121,9 @@ where
// for now. See discussion on [#61069]. // for now. See discussion on [#61069].
// //
// [#61069]: https://github.com/rust-lang/rust/pull/61069 // [#61069]: https://github.com/rust-lang/rust/pull/61069
self.trans.gen(dropped_place.local); if !dropped_place.is_indirect() {
self.trans.gen(dropped_place.local);
}
} }
TerminatorKind::Abort TerminatorKind::Abort

View file

@ -35,6 +35,7 @@
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::fx::FxHashMap;
use rustc_index::bit_set::BitSet;
use rustc_index::vec::IndexVec; use rustc_index::vec::IndexVec;
use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*; use rustc_middle::mir::*;
@ -589,7 +590,7 @@ impl Map {
) -> Self { ) -> Self {
let mut map = Self::new(); let mut map = Self::new();
let exclude = excluded_locals(body); let exclude = excluded_locals(body);
map.register_with_filter(tcx, body, filter, &exclude); map.register_with_filter(tcx, body, filter, exclude);
debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len()); debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len());
map map
} }
@ -600,12 +601,12 @@ impl Map {
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
body: &Body<'tcx>, body: &Body<'tcx>,
mut filter: impl FnMut(Ty<'tcx>) -> bool, mut filter: impl FnMut(Ty<'tcx>) -> bool,
exclude: &IndexVec<Local, bool>, exclude: BitSet<Local>,
) { ) {
// We use this vector as stack, pushing and popping projections. // We use this vector as stack, pushing and popping projections.
let mut projection = Vec::new(); let mut projection = Vec::new();
for (local, decl) in body.local_decls.iter_enumerated() { for (local, decl) in body.local_decls.iter_enumerated() {
if !exclude[local] { if !exclude.contains(local) {
self.register_with_filter_rec(tcx, local, &mut projection, decl.ty, &mut filter); self.register_with_filter_rec(tcx, local, &mut projection, decl.ty, &mut filter);
} }
} }
@ -823,26 +824,27 @@ pub fn iter_fields<'tcx>(
} }
/// Returns all locals with projections that have their reference or address taken. /// Returns all locals with projections that have their reference or address taken.
pub fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> { pub fn excluded_locals(body: &Body<'_>) -> BitSet<Local> {
struct Collector { struct Collector {
result: IndexVec<Local, bool>, result: BitSet<Local>,
} }
impl<'tcx> Visitor<'tcx> for Collector { impl<'tcx> Visitor<'tcx> for Collector {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if context.is_borrow() if (context.is_borrow()
|| context.is_address_of() || context.is_address_of()
|| context.is_drop() || context.is_drop()
|| context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput) || context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput))
&& !place.is_indirect()
{ {
// A pointer to a place could be used to access other places with the same local, // A pointer to a place could be used to access other places with the same local,
// hence we have to exclude the local completely. // hence we have to exclude the local completely.
self.result[place.local] = true; self.result.insert(place.local);
} }
} }
} }
let mut collector = Collector { result: IndexVec::from_elem(false, &body.local_decls) }; let mut collector = Collector { result: BitSet::new_empty(body.local_decls.len()) };
collector.visit_body(body); collector.visit_body(body);
collector.result collector.result
} }

View file

@ -1,5 +1,5 @@
use crate::MirPass; use crate::MirPass;
use rustc_index::bit_set::BitSet; use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec; use rustc_index::vec::IndexVec;
use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*; use rustc_middle::mir::visit::*;
@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
debug!(?replacements); debug!(?replacements);
let all_dead_locals = replace_flattened_locals(tcx, body, replacements); let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
if !all_dead_locals.is_empty() { if !all_dead_locals.is_empty() {
for local in excluded.indices() { excluded.union(&all_dead_locals);
excluded[local] |= all_dead_locals.contains(local); excluded = {
} let mut growable = GrowableBitSet::from(excluded);
excluded.raw.resize(body.local_decls.len(), false); growable.ensure(body.local_decls.len());
growable.into()
};
} else { } else {
break; break;
} }
@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
/// - the locals is a union or an enum; /// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to /// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code. /// client code.
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> { fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len()); let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() { for (local, decl) in body.local_decls().iter_enumerated() {
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] { if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local); set.insert(local);
} }
} }
@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
body: &mut Body<'tcx>, body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>, replacements: ReplacementMap<'tcx>,
) -> BitSet<Local> { ) -> BitSet<Local> {
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
for (local, replacements) in replacements.fragments.iter_enumerated() { for (local, replacements) in replacements.fragments.iter_enumerated() {
if replacements.is_some() { if replacements.is_some() {
all_dead_locals.insert(local); all_dead_locals.insert(local);