Auto merge of #85158 - JulianKnodt:array_const_val, r=cjgillot
Mir-Opt for copying enums with large discrepancies I have been meaning to make this for quite a while, based off of this [hackmd](https://hackmd.io/`@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).` I'm not sure where to put this opt now that I've made it, so I'd appreciate suggestions on that! It's also one long chain of statements, not sure if there's a more friendly format to make it. r? `@tmiasko` I would `r` oli but he's on leave so he suggested I `r` tmiasko or wesleywiser.
This commit is contained in:
commit
5a8dfd933a
11 changed files with 789 additions and 0 deletions
298
compiler/rustc_mir_transform/src/large_enums.rs
Normal file
298
compiler/rustc_mir_transform/src/large_enums.rs
Normal file
|
@ -0,0 +1,298 @@
|
|||
use crate::rustc_middle::ty::util::IntTypeExt;
|
||||
use crate::MirPass;
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_middle::mir::interpret::AllocId;
|
||||
use rustc_middle::mir::*;
|
||||
use rustc_middle::ty::{self, AdtDef, Const, ParamEnv, Ty, TyCtxt};
|
||||
use rustc_session::Session;
|
||||
use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
|
||||
|
||||
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
|
||||
/// enough discrepancy between them.
|
||||
///
|
||||
/// i.e. If there is are two variants:
|
||||
/// ```
|
||||
/// enum Example {
|
||||
/// Small,
|
||||
/// Large([u32; 1024]),
|
||||
/// }
|
||||
/// ```
|
||||
/// Instead of emitting moves of the large variant,
|
||||
/// Perform a memcpy instead.
|
||||
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
|
||||
///
|
||||
/// In summary, what this does is at runtime determine which enum variant is active,
|
||||
/// and instead of copying all the bytes of the largest possible variant,
|
||||
/// copy only the bytes for the currently active variant.
|
||||
pub struct EnumSizeOpt {
|
||||
pub(crate) discrepancy: u64,
|
||||
}
|
||||
|
||||
impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
|
||||
fn is_enabled(&self, sess: &Session) -> bool {
|
||||
sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
|
||||
}
|
||||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
// NOTE: This pass may produce different MIR based on the alignment of the target
|
||||
// platform, but it will still be valid.
|
||||
self.optim(tcx, body);
|
||||
}
|
||||
}
|
||||
|
||||
impl EnumSizeOpt {
|
||||
fn candidate<'tcx>(
|
||||
&self,
|
||||
tcx: TyCtxt<'tcx>,
|
||||
param_env: ParamEnv<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
|
||||
) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
|
||||
let adt_def = match ty.kind() {
|
||||
ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
|
||||
_ => return None,
|
||||
};
|
||||
let layout = tcx.layout_of(param_env.and(ty)).ok()?;
|
||||
let variants = match &layout.variants {
|
||||
Variants::Single { .. } => return None,
|
||||
Variants::Multiple { tag_encoding, .. }
|
||||
if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
|
||||
{
|
||||
return None;
|
||||
}
|
||||
Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
|
||||
Variants::Multiple { variants, .. } => variants,
|
||||
};
|
||||
let min = variants.iter().map(|v| v.size).min().unwrap();
|
||||
let max = variants.iter().map(|v| v.size).max().unwrap();
|
||||
if max.bytes() - min.bytes() < self.discrepancy {
|
||||
return None;
|
||||
}
|
||||
|
||||
let num_discrs = adt_def.discriminants(tcx).count();
|
||||
if variants.iter_enumerated().any(|(var_idx, _)| {
|
||||
let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
|
||||
(discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
|
||||
}) {
|
||||
return None;
|
||||
}
|
||||
if let Some(alloc_id) = alloc_cache.get(&ty) {
|
||||
return Some((*adt_def, num_discrs, *alloc_id));
|
||||
}
|
||||
|
||||
let data_layout = tcx.data_layout();
|
||||
let ptr_sized_int = data_layout.ptr_sized_integer();
|
||||
let target_bytes = ptr_sized_int.size().bytes() as usize;
|
||||
let mut data = vec![0; target_bytes * num_discrs];
|
||||
macro_rules! encode_store {
|
||||
($curr_idx: expr, $endian: expr, $bytes: expr) => {
|
||||
let bytes = match $endian {
|
||||
rustc_target::abi::Endian::Little => $bytes.to_le_bytes(),
|
||||
rustc_target::abi::Endian::Big => $bytes.to_be_bytes(),
|
||||
};
|
||||
for (i, b) in bytes.into_iter().enumerate() {
|
||||
data[$curr_idx + i] = b;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
for (var_idx, layout) in variants.iter_enumerated() {
|
||||
let curr_idx =
|
||||
target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
|
||||
let sz = layout.size;
|
||||
match ptr_sized_int {
|
||||
rustc_target::abi::Integer::I32 => {
|
||||
encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
|
||||
}
|
||||
rustc_target::abi::Integer::I64 => {
|
||||
encode_store!(curr_idx, data_layout.endian, sz.bytes());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
}
|
||||
let alloc = interpret::Allocation::from_bytes(
|
||||
data,
|
||||
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
|
||||
Mutability::Not,
|
||||
);
|
||||
let alloc = tcx.create_memory_alloc(tcx.intern_const_alloc(alloc));
|
||||
Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
|
||||
}
|
||||
fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
||||
let mut alloc_cache = FxHashMap::default();
|
||||
let body_did = body.source.def_id();
|
||||
let param_env = tcx.param_env(body_did);
|
||||
|
||||
let blocks = body.basic_blocks.as_mut();
|
||||
let local_decls = &mut body.local_decls;
|
||||
|
||||
for bb in blocks {
|
||||
bb.expand_statements(|st| {
|
||||
if let StatementKind::Assign(box (
|
||||
lhs,
|
||||
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
|
||||
)) = &st.kind
|
||||
{
|
||||
let ty = lhs.ty(local_decls, tcx).ty;
|
||||
|
||||
let source_info = st.source_info;
|
||||
let span = source_info.span;
|
||||
|
||||
let (adt_def, num_variants, alloc_id) =
|
||||
self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
|
||||
let alloc = tcx.global_alloc(alloc_id).unwrap_memory();
|
||||
|
||||
let tmp_ty = tcx.mk_ty(ty::Array(
|
||||
tcx.types.usize,
|
||||
Const::from_usize(tcx, num_variants as u64),
|
||||
));
|
||||
|
||||
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
|
||||
let store_live = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::StorageLive(size_array_local),
|
||||
};
|
||||
|
||||
let place = Place::from(size_array_local);
|
||||
let constant_vals = Constant {
|
||||
span,
|
||||
user_ty: None,
|
||||
literal: ConstantKind::Val(
|
||||
interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
|
||||
tmp_ty,
|
||||
),
|
||||
};
|
||||
let rval = Rvalue::Use(Operand::Constant(box (constant_vals)));
|
||||
|
||||
let const_assign =
|
||||
Statement { source_info, kind: StatementKind::Assign(box (place, rval)) };
|
||||
|
||||
let discr_place = Place::from(
|
||||
local_decls
|
||||
.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
|
||||
);
|
||||
|
||||
let store_discr = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))),
|
||||
};
|
||||
|
||||
let discr_cast_place =
|
||||
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
|
||||
|
||||
let cast_discr = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
discr_cast_place,
|
||||
Rvalue::Cast(
|
||||
CastKind::IntToInt,
|
||||
Operand::Copy(discr_place),
|
||||
tcx.types.usize,
|
||||
),
|
||||
)),
|
||||
};
|
||||
|
||||
let size_place =
|
||||
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
|
||||
|
||||
let store_size = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
size_place,
|
||||
Rvalue::Use(Operand::Copy(Place {
|
||||
local: size_array_local,
|
||||
projection: tcx.intern_place_elems(&[PlaceElem::Index(
|
||||
discr_cast_place.local,
|
||||
)]),
|
||||
})),
|
||||
)),
|
||||
};
|
||||
|
||||
let dst =
|
||||
Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));
|
||||
|
||||
let dst_ptr = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
dst,
|
||||
Rvalue::AddressOf(Mutability::Mut, *lhs),
|
||||
)),
|
||||
};
|
||||
|
||||
let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
|
||||
let dst_cast_place =
|
||||
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
|
||||
|
||||
let dst_cast = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
dst_cast_place,
|
||||
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
|
||||
)),
|
||||
};
|
||||
|
||||
let src =
|
||||
Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));
|
||||
|
||||
let src_ptr = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
src,
|
||||
Rvalue::AddressOf(Mutability::Not, *rhs),
|
||||
)),
|
||||
};
|
||||
|
||||
let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
|
||||
let src_cast_place =
|
||||
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
|
||||
|
||||
let src_cast = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Assign(box (
|
||||
src_cast_place,
|
||||
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
|
||||
)),
|
||||
};
|
||||
|
||||
let deinit_old =
|
||||
Statement { source_info, kind: StatementKind::Deinit(box dst) };
|
||||
|
||||
let copy_bytes = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::Intrinsic(
|
||||
box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
|
||||
src: Operand::Copy(src_cast_place),
|
||||
dst: Operand::Copy(dst_cast_place),
|
||||
count: Operand::Copy(size_place),
|
||||
}),
|
||||
),
|
||||
};
|
||||
|
||||
let store_dead = Statement {
|
||||
source_info,
|
||||
kind: StatementKind::StorageDead(size_array_local),
|
||||
};
|
||||
let iter = [
|
||||
store_live,
|
||||
const_assign,
|
||||
store_discr,
|
||||
cast_discr,
|
||||
store_size,
|
||||
dst_ptr,
|
||||
dst_cast,
|
||||
src_ptr,
|
||||
src_cast,
|
||||
deinit_old,
|
||||
copy_bytes,
|
||||
store_dead,
|
||||
]
|
||||
.into_iter();
|
||||
|
||||
st.make_nop();
|
||||
Some(iter)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
#![allow(rustc::potential_query_instability)]
|
||||
#![feature(box_patterns)]
|
||||
#![feature(drain_filter)]
|
||||
#![feature(box_syntax)]
|
||||
#![feature(let_chains)]
|
||||
#![feature(map_try_insert)]
|
||||
#![feature(min_specialization)]
|
||||
|
@ -73,6 +74,7 @@ mod function_item_references;
|
|||
mod generator;
|
||||
mod inline;
|
||||
mod instcombine;
|
||||
mod large_enums;
|
||||
mod lower_intrinsics;
|
||||
mod lower_slice_len;
|
||||
mod match_branches;
|
||||
|
@ -583,6 +585,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
|
|||
&simplify::SimplifyLocals::new("final"),
|
||||
&multiple_return_terminators::MultipleReturnTerminators,
|
||||
&deduplicate_blocks::DeduplicateBlocks,
|
||||
&large_enums::EnumSizeOpt { discrepancy: 128 },
|
||||
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
|
||||
&add_call_guards::CriticalCallEdges,
|
||||
// Dump the end result for testing and debugging purposes.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue