Fix SROA without deaggregation.
This commit is contained in:
parent
3de7d7fb22
commit
0843acbea6
10 changed files with 175 additions and 128 deletions
|
@ -2,6 +2,7 @@ use crate::MirPass;
|
|||
use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
|
||||
use rustc_index::bit_set::BitSet;
|
||||
use rustc_index::vec::IndexVec;
|
||||
use rustc_middle::mir::patch::MirPatch;
|
||||
use rustc_middle::mir::visit::*;
|
||||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::TyCtxt;
|
||||
|
@ -13,7 +14,9 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
|
|||
sess.mir_opt_level() >= 3
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip(self, tcx, body))]
|
||||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
debug!(def_id = ?body.source.def_id());
|
||||
let escaping = escaping_locals(&*body);
|
||||
debug!(?escaping);
|
||||
let replacements = compute_flattening(tcx, body, escaping);
|
||||
|
@ -69,15 +72,28 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
|
|||
self.super_rvalue(rvalue, location)
|
||||
}
|
||||
|
||||
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
|
||||
if let StatementKind::StorageLive(..)
|
||||
| StatementKind::StorageDead(..)
|
||||
| StatementKind::Deinit(..) = statement.kind
|
||||
{
|
||||
// Storage statements are expanded in run_pass.
|
||||
fn visit_assign(
|
||||
&mut self,
|
||||
lvalue: &Place<'tcx>,
|
||||
rvalue: &Rvalue<'tcx>,
|
||||
location: Location,
|
||||
) {
|
||||
if lvalue.as_local().is_some() && let Rvalue::Aggregate(..) = rvalue {
|
||||
// Aggregate assignments are expanded in run_pass.
|
||||
self.visit_rvalue(rvalue, location);
|
||||
return;
|
||||
}
|
||||
self.super_statement(statement, location)
|
||||
self.super_assign(lvalue, rvalue, location)
|
||||
}
|
||||
|
||||
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
|
||||
match statement.kind {
|
||||
// Storage statements are expanded in run_pass.
|
||||
StatementKind::StorageLive(..)
|
||||
| StatementKind::StorageDead(..)
|
||||
| StatementKind::Deinit(..) => return,
|
||||
_ => self.super_statement(statement, location),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
|
||||
|
@ -192,6 +208,7 @@ fn replace_flattened_locals<'tcx>(
|
|||
replacements,
|
||||
all_dead_locals,
|
||||
fragments,
|
||||
patch: MirPatch::new(body),
|
||||
};
|
||||
for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
|
||||
visitor.visit_basic_block_data(bb, data);
|
||||
|
@ -205,6 +222,7 @@ fn replace_flattened_locals<'tcx>(
|
|||
for var_debug_info in &mut body.var_debug_info {
|
||||
visitor.visit_var_debug_info(var_debug_info);
|
||||
}
|
||||
visitor.patch.apply(body);
|
||||
}
|
||||
|
||||
struct ReplacementVisitor<'tcx, 'll> {
|
||||
|
@ -218,6 +236,7 @@ struct ReplacementVisitor<'tcx, 'll> {
|
|||
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
|
||||
/// and deinit statement and debuginfo.
|
||||
fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
|
||||
patch: MirPatch<'tcx>,
|
||||
}
|
||||
|
||||
impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
|
||||
|
@ -255,12 +274,63 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
|
|||
}
|
||||
|
||||
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
|
||||
if let StatementKind::StorageLive(..)
|
||||
| StatementKind::StorageDead(..)
|
||||
| StatementKind::Deinit(..) = statement.kind
|
||||
{
|
||||
// Storage statements are expanded in run_pass.
|
||||
return;
|
||||
match statement.kind {
|
||||
StatementKind::StorageLive(l) => {
|
||||
if self.all_dead_locals.contains(l) {
|
||||
let final_locals = &self.fragments[l];
|
||||
for &(_, fl) in final_locals {
|
||||
self.patch.add_statement(location, StatementKind::StorageLive(fl));
|
||||
}
|
||||
statement.make_nop();
|
||||
}
|
||||
return;
|
||||
}
|
||||
StatementKind::StorageDead(l) => {
|
||||
if self.all_dead_locals.contains(l) {
|
||||
let final_locals = &self.fragments[l];
|
||||
for &(_, fl) in final_locals {
|
||||
self.patch.add_statement(location, StatementKind::StorageDead(fl));
|
||||
}
|
||||
statement.make_nop();
|
||||
}
|
||||
return;
|
||||
}
|
||||
StatementKind::Deinit(box ref place) => {
|
||||
if let Some(local) = place.as_local()
|
||||
&& self.all_dead_locals.contains(local)
|
||||
{
|
||||
let final_locals = &self.fragments[local];
|
||||
for &(_, fl) in final_locals {
|
||||
self.patch.add_statement(
|
||||
location,
|
||||
StatementKind::Deinit(Box::new(fl.into())),
|
||||
);
|
||||
}
|
||||
statement.make_nop();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
StatementKind::Assign(box (ref place, Rvalue::Aggregate(_, ref operands))) => {
|
||||
if let Some(local) = place.as_local()
|
||||
&& self.all_dead_locals.contains(local)
|
||||
{
|
||||
let final_locals = &self.fragments[local];
|
||||
for &(projection, fl) in final_locals {
|
||||
let &[PlaceElem::Field(index, _)] = projection else { bug!() };
|
||||
let index = index.as_usize();
|
||||
let rvalue = Rvalue::Use(operands[index].clone());
|
||||
self.patch.add_statement(
|
||||
location,
|
||||
StatementKind::Assign(Box::new((fl.into(), rvalue))),
|
||||
);
|
||||
}
|
||||
statement.make_nop();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
self.super_statement(statement, location)
|
||||
}
|
||||
|
@ -309,39 +379,6 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
|
|||
}
|
||||
}
|
||||
|
||||
fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
|
||||
self.super_basic_block_data(bb, bbdata);
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Stmt {
|
||||
StorageLive,
|
||||
StorageDead,
|
||||
Deinit,
|
||||
}
|
||||
|
||||
bbdata.expand_statements(|stmt| {
|
||||
let source_info = stmt.source_info;
|
||||
let (stmt, origin_local) = match &stmt.kind {
|
||||
StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
|
||||
StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
|
||||
StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
|
||||
_ => return None,
|
||||
};
|
||||
if !self.all_dead_locals.contains(origin_local) {
|
||||
return None;
|
||||
}
|
||||
let final_locals = self.fragments.get(origin_local)?;
|
||||
Some(final_locals.iter().map(move |&(_, l)| {
|
||||
let kind = match stmt {
|
||||
Stmt::StorageLive => StatementKind::StorageLive(l),
|
||||
Stmt::StorageDead => StatementKind::StorageDead(l),
|
||||
Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
|
||||
};
|
||||
Statement { source_info, kind }
|
||||
}))
|
||||
});
|
||||
}
|
||||
|
||||
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
|
||||
assert!(!self.all_dead_locals.contains(*local));
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue