Only exclude locals if the place is not indirect.

This commit is contained in:
Camille GILLOT 2023-01-21 21:53:26 +00:00
parent 0d59b8c997
commit cd3649b2a5
3 changed files with 24 additions and 18 deletions

View file

@ -1,5 +1,5 @@
use crate::MirPass;
use rustc_index::bit_set::BitSet;
use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
debug!(?replacements);
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
if !all_dead_locals.is_empty() {
for local in excluded.indices() {
excluded[local] |= all_dead_locals.contains(local);
}
excluded.raw.resize(body.local_decls.len(), false);
excluded.union(&all_dead_locals);
excluded = {
let mut growable = GrowableBitSet::from(excluded);
growable.ensure(body.local_decls.len());
growable.into()
};
} else {
break;
}
@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
/// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code.
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() {
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local);
}
}
@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>(
body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>,
) -> BitSet<Local> {
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
for (local, replacements) in replacements.fragments.iter_enumerated() {
if replacements.is_some() {
all_dead_locals.insert(local);