1
Fork 0

Transforms a match containing negative numbers into an assignment statement as well

This commit is contained in:
DianQK 2024-02-20 21:55:46 +08:00
parent 1f061f47e2
commit e752af765e
No known key found for this signature in database
4 changed files with 100 additions and 59 deletions

View file

@ -1,6 +1,7 @@
use rustc_index::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
use rustc_target::abi::Size;
use std::iter;
use super::simplify::simplify_cfg;
@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
_ => unreachable!(),
};
if !self.can_simplify(tcx, targets, param_env, bbs) {
let discr_ty = discr.ty(local_decls, tcx);
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
return false;
}
// Take ownership of items now that we know we can optimize.
let discr = discr.clone();
let discr_ty = discr.ty(local_decls, tcx);
// Introduce a temporary for the discriminant value.
let source_info = bbs[switch_bb_idx].terminator().source_info;
@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool;
fn new_stmts(
@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
_discr_ty: Ty<'tcx>,
) -> bool {
if targets.iter().len() != 1 {
return false;
@ -268,7 +271,7 @@ struct SimplifyToExp {
enum CompareType<'tcx, 'a> {
Same(&'a StatementKind<'tcx>),
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
Discr(&'a Place<'tcx>, Ty<'tcx>),
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
}
enum TransfromType {
@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
match compare_type {
CompareType::Same(_) => TransfromType::Same,
CompareType::Eq(_, _, _) => TransfromType::Eq,
CompareType::Discr(_, _) => TransfromType::Discr,
CompareType::Discr(_, _, _) => TransfromType::Discr,
}
}
}
@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
targets: &SwitchTargets,
param_env: ParamEnv<'tcx>,
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
discr_ty: Ty<'tcx>,
) -> bool {
if targets.iter().len() < 2 || targets.iter().len() > 64 {
return false;
@ -355,6 +359,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return false;
}
let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
let first_stmts = &bbs[first_target].statements;
let (second_val, second_target) = target_iter.next().unwrap();
let second_stmts = &bbs[second_target].statements;
@ -362,6 +367,11 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return false;
}
fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
l.try_to_int(l.size()).unwrap()
== ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
}
let mut compare_types = Vec::new();
for (f, s) in iter::zip(first_stmts, second_stmts) {
let compare_type = match (&f.kind, &s.kind) {
@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
) {
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
(Some(f), Some(s))
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
&& int_equal(f, first_val, discr_size)
&& int_equal(s, second_val, discr_size))
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s)
== ScalarInt::try_from_uint(second_val, s.size())) =>
{
CompareType::Discr(lhs_f, f_c.const_.ty())
CompareType::Discr(
lhs_f,
f_c.const_.ty(),
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
)
}
_ => {
return false;
}
_ => return false,
}
}
@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
&& s_c.const_.ty() == f_ty
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
(
CompareType::Discr(lhs_f, f_ty),
CompareType::Discr(lhs_f, f_ty, is_signed),
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
return false;
};
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
return false;
if is_signed
&& s_c.const_.ty().is_signed()
&& int_equal(f, other_val, discr_size)
{
continue;
}
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
continue;
}
return false;
}
_ => return false,
}