1
Fork 0

Use MirPatch in EnumSizeOpt.

Instead of `expand_statements`. This makes the code shorter and
consistent with other MIR transform passes.

The tests require updating because there is a slight change in
MIR output:
- the old code replaced the original statement with twelve new
  statements.
- the new code inserts converts the original statement to a `nop` and
  then insert twelve new statements in front of it.

I.e. we now end up with an extra `nop`, which doesn't matter at all.
This commit is contained in:
Nicholas Nethercote 2025-02-18 11:24:57 +11:00
parent ce36a966c7
commit a1daa34ad0
5 changed files with 77 additions and 107 deletions

View file

@ -6,6 +6,8 @@ use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt}; use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
use rustc_session::Session; use rustc_session::Session;
use crate::patch::MirPatch;
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large /// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between them. /// enough discrepancy between them.
/// ///
@ -41,31 +43,34 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
let mut alloc_cache = FxHashMap::default(); let mut alloc_cache = FxHashMap::default();
let typing_env = body.typing_env(tcx); let typing_env = body.typing_env(tcx);
let blocks = body.basic_blocks.as_mut(); let mut patch = MirPatch::new(body);
let local_decls = &mut body.local_decls;
for bb in blocks { for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
bb.expand_statements(|st| { for (statement_index, st) in data.statements.iter_mut().enumerate() {
let StatementKind::Assign(box ( let StatementKind::Assign(box (
lhs, lhs,
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) = &st.kind )) = &st.kind
else { else {
return None; continue;
}; };
let ty = lhs.ty(local_decls, tcx).ty; let location = Location { block, statement_index };
let (adt_def, num_variants, alloc_id) = let ty = lhs.ty(&body.local_decls, tcx).ty;
self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
let source_info = st.source_info; let Some((adt_def, num_variants, alloc_id)) =
let span = source_info.span; self.candidate(tcx, typing_env, ty, &mut alloc_cache)
else {
continue;
};
let span = st.source_info.span;
let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64); let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span)); let size_array_local = patch.new_temp(tmp_ty, span);
let store_live =
Statement { source_info, kind: StatementKind::StorageLive(size_array_local) }; let store_live = StatementKind::StorageLive(size_array_local);
let place = Place::from(size_array_local); let place = Place::from(size_array_local);
let constant_vals = ConstOperand { let constant_vals = ConstOperand {
@ -77,108 +82,63 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
), ),
}; };
let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals))); let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
let const_assign = let const_assign = StatementKind::Assign(Box::new((place, rval)));
Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };
let discr_place = Place::from( let discr_place =
local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)), Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
); let store_discr =
let store_discr = Statement { StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));
source_info,
kind: StatementKind::Assign(Box::new((
discr_place,
Rvalue::Discriminant(*rhs),
))),
};
let discr_cast_place = let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); let cast_discr = StatementKind::Assign(Box::new((
let cast_discr = Statement { discr_cast_place,
source_info, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
kind: StatementKind::Assign(Box::new(( )));
discr_cast_place,
Rvalue::Cast(
CastKind::IntToInt,
Operand::Copy(discr_place),
tcx.types.usize,
),
))),
};
let size_place = let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); let store_size = StatementKind::Assign(Box::new((
let store_size = Statement { size_place,
source_info, Rvalue::Use(Operand::Copy(Place {
kind: StatementKind::Assign(Box::new(( local: size_array_local,
size_place, projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
Rvalue::Use(Operand::Copy(Place { })),
local: size_array_local, )));
projection: tcx
.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
})),
))),
};
let dst = let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span))); let dst_ptr =
let dst_ptr = Statement { StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));
source_info,
kind: StatementKind::Assign(Box::new((
dst,
Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
))),
};
let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8); let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
let dst_cast_place = let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span))); let dst_cast = StatementKind::Assign(Box::new((
let dst_cast = Statement { dst_cast_place,
source_info, Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
kind: StatementKind::Assign(Box::new(( )));
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
))),
};
let src = let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span))); let src_ptr =
let src_ptr = Statement { StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));
source_info,
kind: StatementKind::Assign(Box::new((
src,
Rvalue::RawPtr(RawPtrKind::Const, *rhs),
))),
};
let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8); let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
let src_cast_place = let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span))); let src_cast = StatementKind::Assign(Box::new((
let src_cast = Statement { src_cast_place,
source_info, Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
kind: StatementKind::Assign(Box::new(( )));
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
))),
};
let deinit_old = let deinit_old = StatementKind::Deinit(Box::new(dst));
Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
let copy_bytes = Statement { let copy_bytes = StatementKind::Intrinsic(Box::new(
source_info, NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
kind: StatementKind::Intrinsic(Box::new( src: Operand::Copy(src_cast_place),
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { dst: Operand::Copy(dst_cast_place),
src: Operand::Copy(src_cast_place), count: Operand::Copy(size_place),
dst: Operand::Copy(dst_cast_place), }),
count: Operand::Copy(size_place), ));
}),
)),
};
let store_dead = let store_dead = StatementKind::StorageDead(size_array_local);
Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
let iter = [ let stmts = [
store_live, store_live,
const_assign, const_assign,
store_discr, store_discr,
@ -191,14 +151,16 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
deinit_old, deinit_old,
copy_bytes, copy_bytes,
store_dead, store_dead,
] ];
.into_iter(); for stmt in stmts {
patch.add_statement(location, stmt);
}
st.make_nop(); st.make_nop();
}
Some(iter)
});
} }
patch.apply(body);
} }
fn is_required(&self) -> bool { fn is_required(&self) -> bool {

View file

@ -47,6 +47,7 @@
+ Deinit(_8); + Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7); + copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4); + StorageDead(_4);
+ nop;
StorageDead(_2); StorageDead(_2);
- _0 = move _1; - _0 = move _1;
+ StorageLive(_12); + StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16); + Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15); + copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12); + StorageDead(_12);
+ nop;
StorageDead(_1); StorageDead(_1);
return; return;
} }

View file

@ -47,6 +47,7 @@
+ Deinit(_8); + Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7); + copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4); + StorageDead(_4);
+ nop;
StorageDead(_2); StorageDead(_2);
- _0 = move _1; - _0 = move _1;
+ StorageLive(_12); + StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16); + Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15); + copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12); + StorageDead(_12);
+ nop;
StorageDead(_1); StorageDead(_1);
return; return;
} }

View file

@ -47,6 +47,7 @@
+ Deinit(_8); + Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7); + copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4); + StorageDead(_4);
+ nop;
StorageDead(_2); StorageDead(_2);
- _0 = move _1; - _0 = move _1;
+ StorageLive(_12); + StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16); + Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15); + copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12); + StorageDead(_12);
+ nop;
StorageDead(_1); StorageDead(_1);
return; return;
} }

View file

@ -47,6 +47,7 @@
+ Deinit(_8); + Deinit(_8);
+ copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7); + copy_nonoverlapping(dst = copy _9, src = copy _11, count = copy _7);
+ StorageDead(_4); + StorageDead(_4);
+ nop;
StorageDead(_2); StorageDead(_2);
- _0 = move _1; - _0 = move _1;
+ StorageLive(_12); + StorageLive(_12);
@ -61,6 +62,7 @@
+ Deinit(_16); + Deinit(_16);
+ copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15); + copy_nonoverlapping(dst = copy _17, src = copy _19, count = copy _15);
+ StorageDead(_12); + StorageDead(_12);
+ nop;
StorageDead(_1); StorageDead(_1);
return; return;
} }