Use MirPatch in EnumSizeOpt.

Instead of `expand_statements`. This makes the code shorter and
consistent with other MIR transform passes.

The tests require updating because there is a slight change in
MIR output:
- the old code replaced the original statement with twelve new
  statements.
- the new code inserts converts the original statement to a `nop` and
  then insert twelve new statements in front of it.

I.e. we now end up with an extra `nop`, which doesn't matter at all.
This commit is contained in:
Nicholas Nethercote 2025-02-18 11:24:57 +11:00
parent ce36a966c7
commit a1daa34ad0
5 changed files with 77 additions and 107 deletions

View file

@ -6,6 +6,8 @@ use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
use rustc_session::Session;
use crate::patch::MirPatch;
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between them.
///
@ -41,31 +43,34 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
let mut alloc_cache = FxHashMap::default();
let typing_env = body.typing_env(tcx);
let blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;
let mut patch = MirPatch::new(body);
for bb in blocks {
bb.expand_statements(|st| {
for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
for (statement_index, st) in data.statements.iter_mut().enumerate() {
let StatementKind::Assign(box (
lhs,
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) = &st.kind
else {
return None;
continue;
};
let ty = lhs.ty(local_decls, tcx).ty;
let location = Location { block, statement_index };
let (adt_def, num_variants, alloc_id) =
self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
let ty = lhs.ty(&body.local_decls, tcx).ty;
let source_info = st.source_info;
let span = source_info.span;
let Some((adt_def, num_variants, alloc_id)) =
self.candidate(tcx, typing_env, ty, &mut alloc_cache)
else {
continue;
};
let span = st.source_info.span;
let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
let store_live =
Statement { source_info, kind: StatementKind::StorageLive(size_array_local) };
let size_array_local = patch.new_temp(tmp_ty, span);
let store_live = StatementKind::StorageLive(size_array_local);
let place = Place::from(size_array_local);
let constant_vals = ConstOperand {
@ -77,108 +82,63 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
),
};
let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
let const_assign =
Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };
let const_assign = StatementKind::Assign(Box::new((place, rval)));
let discr_place = Place::from(
local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
);
let store_discr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
discr_place,
Rvalue::Discriminant(*rhs),
))),
};
let discr_place =
Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
let store_discr =
StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));
let discr_cast_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
let cast_discr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
discr_cast_place,
Rvalue::Cast(
CastKind::IntToInt,
Operand::Copy(discr_place),
tcx.types.usize,
),
))),
};
let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
let cast_discr = StatementKind::Assign(Box::new((
discr_cast_place,
Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
)));
let size_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
let store_size = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
size_place,
Rvalue::Use(Operand::Copy(Place {
local: size_array_local,
projection: tcx
.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
})),
))),
};
let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
let store_size = StatementKind::Assign(Box::new((
size_place,
Rvalue::Use(Operand::Copy(Place {
local: size_array_local,
projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
})),
)));
let dst =
Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span)));
let dst_ptr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
dst,
Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
))),
};
let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
let dst_ptr =
StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));
let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
let dst_cast_place =
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
let dst_cast = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
))),
};
let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
let dst_cast = StatementKind::Assign(Box::new((
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
)));
let src =
Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span)));
let src_ptr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
src,
Rvalue::RawPtr(RawPtrKind::Const, *rhs),
))),
};
let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
let src_ptr =
StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));
let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
let src_cast_place =
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
let src_cast = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
))),
};
let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
let src_cast = StatementKind::Assign(Box::new((
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
)));
let deinit_old =
Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
let deinit_old = StatementKind::Deinit(Box::new(dst));
let copy_bytes = Statement {
source_info,
kind: StatementKind::Intrinsic(Box::new(
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
src: Operand::Copy(src_cast_place),
dst: Operand::Copy(dst_cast_place),
count: Operand::Copy(size_place),
}),
)),
};
let copy_bytes = StatementKind::Intrinsic(Box::new(
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
src: Operand::Copy(src_cast_place),
dst: Operand::Copy(dst_cast_place),
count: Operand::Copy(size_place),
}),
));
let store_dead =
Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
let store_dead = StatementKind::StorageDead(size_array_local);
let iter = [
let stmts = [
store_live,
const_assign,
store_discr,
@ -191,14 +151,16 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
deinit_old,
copy_bytes,
store_dead,
]
.into_iter();
];
for stmt in stmts {
patch.add_statement(location, stmt);
}
st.make_nop();
Some(iter)
});
}
}
patch.apply(body);
}
fn is_required(&self) -> bool {

View file

@ -47,6 +47,7 @@
+ Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4);
+ nop;
StorageDead(_2);
- _0 = move _1;
+ StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12);
+ nop;
StorageDead(_1);
return;
}

View file

@ -47,6 +47,7 @@
+ Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4);
+ nop;
StorageDead(_2);
- _0 = move _1;
+ StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12);
+ nop;
StorageDead(_1);
return;
}

View file

@ -47,6 +47,7 @@
+ Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4);
+ nop;
StorageDead(_2);
- _0 = move _1;
+ StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12);
+ nop;
StorageDead(_1);
return;
}

View file

@ -47,6 +47,7 @@
+ Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4);
+ nop;
StorageDead(_2);
- _0 = move _1;
+ StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12);
+ nop;
StorageDead(_1);
return;
}