1
Fork 0

Fix SROA without deaggregation.

This commit is contained in:
Camille GILLOT 2023-02-04 14:39:42 +00:00
parent 3de7d7fb22
commit 0843acbea6
10 changed files with 175 additions and 128 deletions

View file

@ -2,6 +2,7 @@ use crate::MirPass;
use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
use rustc_index::bit_set::BitSet;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
@ -13,7 +14,9 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
sess.mir_opt_level() >= 3
}
#[instrument(level = "debug", skip(self, tcx, body))]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(def_id = ?body.source.def_id());
let escaping = escaping_locals(&*body);
debug!(?escaping);
let replacements = compute_flattening(tcx, body, escaping);
@ -69,15 +72,28 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
self.super_rvalue(rvalue, location)
}
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
if let StatementKind::StorageLive(..)
| StatementKind::StorageDead(..)
| StatementKind::Deinit(..) = statement.kind
{
// Storage statements are expanded in run_pass.
fn visit_assign(
&mut self,
lvalue: &Place<'tcx>,
rvalue: &Rvalue<'tcx>,
location: Location,
) {
if lvalue.as_local().is_some() && let Rvalue::Aggregate(..) = rvalue {
// Aggregate assignments are expanded in run_pass.
self.visit_rvalue(rvalue, location);
return;
}
self.super_statement(statement, location)
self.super_assign(lvalue, rvalue, location)
}
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
match statement.kind {
// Storage statements are expanded in run_pass.
StatementKind::StorageLive(..)
| StatementKind::StorageDead(..)
| StatementKind::Deinit(..) => return,
_ => self.super_statement(statement, location),
}
}
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
@ -192,6 +208,7 @@ fn replace_flattened_locals<'tcx>(
replacements,
all_dead_locals,
fragments,
patch: MirPatch::new(body),
};
for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
visitor.visit_basic_block_data(bb, data);
@ -205,6 +222,7 @@ fn replace_flattened_locals<'tcx>(
for var_debug_info in &mut body.var_debug_info {
visitor.visit_var_debug_info(var_debug_info);
}
visitor.patch.apply(body);
}
struct ReplacementVisitor<'tcx, 'll> {
@ -218,6 +236,7 @@ struct ReplacementVisitor<'tcx, 'll> {
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
/// and deinit statement and debuginfo.
fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
patch: MirPatch<'tcx>,
}
impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
@ -255,12 +274,63 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
}
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
if let StatementKind::StorageLive(..)
| StatementKind::StorageDead(..)
| StatementKind::Deinit(..) = statement.kind
{
// Storage statements are expanded in run_pass.
return;
match statement.kind {
StatementKind::StorageLive(l) => {
if self.all_dead_locals.contains(l) {
let final_locals = &self.fragments[l];
for &(_, fl) in final_locals {
self.patch.add_statement(location, StatementKind::StorageLive(fl));
}
statement.make_nop();
}
return;
}
StatementKind::StorageDead(l) => {
if self.all_dead_locals.contains(l) {
let final_locals = &self.fragments[l];
for &(_, fl) in final_locals {
self.patch.add_statement(location, StatementKind::StorageDead(fl));
}
statement.make_nop();
}
return;
}
StatementKind::Deinit(box ref place) => {
if let Some(local) = place.as_local()
&& self.all_dead_locals.contains(local)
{
let final_locals = &self.fragments[local];
for &(_, fl) in final_locals {
self.patch.add_statement(
location,
StatementKind::Deinit(Box::new(fl.into())),
);
}
statement.make_nop();
return;
}
}
StatementKind::Assign(box (ref place, Rvalue::Aggregate(_, ref operands))) => {
if let Some(local) = place.as_local()
&& self.all_dead_locals.contains(local)
{
let final_locals = &self.fragments[local];
for &(projection, fl) in final_locals {
let &[PlaceElem::Field(index, _)] = projection else { bug!() };
let index = index.as_usize();
let rvalue = Rvalue::Use(operands[index].clone());
self.patch.add_statement(
location,
StatementKind::Assign(Box::new((fl.into(), rvalue))),
);
}
statement.make_nop();
return;
}
}
_ => {}
}
self.super_statement(statement, location)
}
@ -309,39 +379,6 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
}
}
fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
self.super_basic_block_data(bb, bbdata);
#[derive(Debug)]
enum Stmt {
StorageLive,
StorageDead,
Deinit,
}
bbdata.expand_statements(|stmt| {
let source_info = stmt.source_info;
let (stmt, origin_local) = match &stmt.kind {
StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
_ => return None,
};
if !self.all_dead_locals.contains(origin_local) {
return None;
}
let final_locals = self.fragments.get(origin_local)?;
Some(final_locals.iter().map(move |&(_, l)| {
let kind = match stmt {
Stmt::StorageLive => StatementKind::StorageLive(l),
Stmt::StorageDead => StatementKind::StorageDead(l),
Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
};
Statement { source_info, kind }
}))
});
}
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
assert!(!self.all_dead_locals.contains(*local));
}