Use an interpreter in jump threading.
This commit is contained in:
parent
25f8d01fd8
commit
be9668d398
4 changed files with 197 additions and 27 deletions
|
@ -36,16 +36,21 @@
|
||||||
//! cost by `MAX_COST`.
|
//! cost by `MAX_COST`.
|
||||||
|
|
||||||
use rustc_arena::DroplessArena;
|
use rustc_arena::DroplessArena;
|
||||||
|
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
|
||||||
use rustc_data_structures::fx::FxHashSet;
|
use rustc_data_structures::fx::FxHashSet;
|
||||||
use rustc_index::bit_set::BitSet;
|
use rustc_index::bit_set::BitSet;
|
||||||
use rustc_index::IndexVec;
|
use rustc_index::IndexVec;
|
||||||
|
use rustc_middle::mir::interpret::Scalar;
|
||||||
use rustc_middle::mir::visit::Visitor;
|
use rustc_middle::mir::visit::Visitor;
|
||||||
use rustc_middle::mir::*;
|
use rustc_middle::mir::*;
|
||||||
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
|
use rustc_middle::ty::layout::LayoutOf;
|
||||||
|
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
|
||||||
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
|
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
|
||||||
|
use rustc_span::DUMMY_SP;
|
||||||
use rustc_target::abi::{TagEncoding, Variants};
|
use rustc_target::abi::{TagEncoding, Variants};
|
||||||
|
|
||||||
use crate::cost_checker::CostChecker;
|
use crate::cost_checker::CostChecker;
|
||||||
|
use crate::dataflow_const_prop::DummyMachine;
|
||||||
|
|
||||||
pub struct JumpThreading;
|
pub struct JumpThreading;
|
||||||
|
|
||||||
|
@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
|
||||||
let mut finder = TOFinder {
|
let mut finder = TOFinder {
|
||||||
tcx,
|
tcx,
|
||||||
param_env,
|
param_env,
|
||||||
|
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
|
||||||
body,
|
body,
|
||||||
arena: &arena,
|
arena: &arena,
|
||||||
map: &map,
|
map: &map,
|
||||||
|
@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
|
||||||
debug!(?discr, ?bb);
|
debug!(?discr, ?bb);
|
||||||
|
|
||||||
let discr_ty = discr.ty(body, tcx).ty;
|
let discr_ty = discr.ty(body, tcx).ty;
|
||||||
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
|
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
|
||||||
|
|
||||||
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
|
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
|
||||||
debug!(?discr);
|
debug!(?discr);
|
||||||
|
@ -142,6 +148,7 @@ struct ThreadingOpportunity {
|
||||||
struct TOFinder<'tcx, 'a> {
|
struct TOFinder<'tcx, 'a> {
|
||||||
tcx: TyCtxt<'tcx>,
|
tcx: TyCtxt<'tcx>,
|
||||||
param_env: ty::ParamEnv<'tcx>,
|
param_env: ty::ParamEnv<'tcx>,
|
||||||
|
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
|
||||||
body: &'a Body<'tcx>,
|
body: &'a Body<'tcx>,
|
||||||
map: &'a Map,
|
map: &'a Map,
|
||||||
loop_headers: &'a BitSet<BasicBlock>,
|
loop_headers: &'a BitSet<BasicBlock>,
|
||||||
|
@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(level = "trace", skip(self))]
|
#[instrument(level = "trace", skip(self))]
|
||||||
fn process_operand(
|
fn process_immediate(
|
||||||
&mut self,
|
&mut self,
|
||||||
bb: BasicBlock,
|
bb: BasicBlock,
|
||||||
lhs: PlaceIndex,
|
lhs: PlaceIndex,
|
||||||
rhs: &Operand<'tcx>,
|
rhs: ImmTy<'tcx>,
|
||||||
state: &mut State<ConditionSet<'a>>,
|
state: &mut State<ConditionSet<'a>>,
|
||||||
) -> Option<!> {
|
) -> Option<!> {
|
||||||
let register_opportunity = |c: Condition| {
|
let register_opportunity = |c: Condition| {
|
||||||
|
@ -341,13 +348,60 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
|
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let conditions = state.try_get_idx(lhs, self.map)?;
|
||||||
|
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
|
||||||
|
conditions.iter_matches(int).for_each(register_opportunity);
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(level = "trace", skip(self))]
|
||||||
|
fn process_operand(
|
||||||
|
&mut self,
|
||||||
|
bb: BasicBlock,
|
||||||
|
lhs: PlaceIndex,
|
||||||
|
rhs: &Operand<'tcx>,
|
||||||
|
state: &mut State<ConditionSet<'a>>,
|
||||||
|
) -> Option<!> {
|
||||||
match rhs {
|
match rhs {
|
||||||
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
|
||||||
Operand::Constant(constant) => {
|
Operand::Constant(constant) => {
|
||||||
let conditions = state.try_get_idx(lhs, self.map)?;
|
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
|
||||||
let constant =
|
self.map.for_each_projection_value(
|
||||||
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
|
lhs,
|
||||||
conditions.iter_matches(constant).for_each(register_opportunity);
|
constant,
|
||||||
|
&mut |elem, op| match elem {
|
||||||
|
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
|
||||||
|
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
|
||||||
|
TrackElem::Discriminant => {
|
||||||
|
let variant = self.ecx.read_discriminant(op).ok()?;
|
||||||
|
let discr_value =
|
||||||
|
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
|
||||||
|
Some(discr_value.into())
|
||||||
|
}
|
||||||
|
TrackElem::DerefLen => {
|
||||||
|
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
|
||||||
|
let len_usize = op.len(&self.ecx).ok()?;
|
||||||
|
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
|
||||||
|
Some(ImmTy::from_uint(len_usize, layout).into())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
&mut |place, op| {
|
||||||
|
if let Some(conditions) = state.try_get_idx(place, self.map)
|
||||||
|
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
|
||||||
|
&& let Some(imm) = imm.right()
|
||||||
|
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
|
||||||
|
{
|
||||||
|
conditions.iter_matches(int).for_each(|c: Condition| {
|
||||||
|
self.opportunities.push(ThreadingOpportunity {
|
||||||
|
chain: vec![bb],
|
||||||
|
target: c.target,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
// Transfer the conditions on the copied rhs.
|
// Transfer the conditions on the copied rhs.
|
||||||
Operand::Move(rhs) | Operand::Copy(rhs) => {
|
Operand::Move(rhs) | Operand::Copy(rhs) => {
|
||||||
|
@ -374,18 +428,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
// Below, `lhs` is the return value of `mutated_statement`,
|
// Below, `lhs` is the return value of `mutated_statement`,
|
||||||
// the place to which `conditions` apply.
|
// the place to which `conditions` apply.
|
||||||
|
|
||||||
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
|
|
||||||
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
|
|
||||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
|
|
||||||
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
|
|
||||||
Some(Operand::const_from_scalar(
|
|
||||||
self.tcx,
|
|
||||||
discr.ty,
|
|
||||||
scalar.into(),
|
|
||||||
rustc_span::DUMMY_SP,
|
|
||||||
))
|
|
||||||
};
|
|
||||||
|
|
||||||
match &stmt.kind {
|
match &stmt.kind {
|
||||||
// If we expect `discriminant(place) ?= A`,
|
// If we expect `discriminant(place) ?= A`,
|
||||||
// we have an opportunity if `variant_index ?= A`.
|
// we have an opportunity if `variant_index ?= A`.
|
||||||
|
@ -395,7 +437,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
|
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
|
||||||
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
|
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
|
||||||
// nothing.
|
// nothing.
|
||||||
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
|
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
|
||||||
let writes_discriminant = match enum_layout.variants {
|
let writes_discriminant = match enum_layout.variants {
|
||||||
Variants::Single { index } => {
|
Variants::Single { index } => {
|
||||||
assert_eq!(index, *variant_index);
|
assert_eq!(index, *variant_index);
|
||||||
|
@ -408,8 +450,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
} => *variant_index != untagged_variant,
|
} => *variant_index != untagged_variant,
|
||||||
};
|
};
|
||||||
if writes_discriminant {
|
if writes_discriminant {
|
||||||
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
|
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
|
||||||
self.process_operand(bb, discr_target, &discr, state)?;
|
self.process_immediate(bb, discr_target, discr, state)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
|
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
|
||||||
|
@ -440,10 +482,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
|
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
|
||||||
if let Some(discr_target) =
|
if let Some(discr_target) =
|
||||||
self.map.apply(lhs, TrackElem::Discriminant)
|
self.map.apply(lhs, TrackElem::Discriminant)
|
||||||
&& let Some(discr_value) =
|
&& let Ok(discr_value) = self
|
||||||
discriminant_for_variant(agg_ty, *variant_index)
|
.ecx
|
||||||
|
.discriminant_for_variant(agg_ty, *variant_index)
|
||||||
{
|
{
|
||||||
self.process_operand(bb, discr_target, &discr_value, state);
|
self.process_immediate(
|
||||||
|
bb,
|
||||||
|
discr_target,
|
||||||
|
discr_value,
|
||||||
|
state,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
|
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
|
||||||
}
|
}
|
||||||
|
@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
|
||||||
|
|
||||||
let discr = discr.place()?;
|
let discr = discr.place()?;
|
||||||
let discr_ty = discr.ty(self.body, self.tcx).ty;
|
let discr_ty = discr.ty(self.body, self.tcx).ty;
|
||||||
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
|
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
|
||||||
let conditions = state.try_get(discr.as_ref(), self.map)?;
|
let conditions = state.try_get(discr.as_ref(), self.map)?;
|
||||||
|
|
||||||
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
|
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
- // MIR for `aggregate` before JumpThreading
|
||||||
|
+ // MIR for `aggregate` after JumpThreading
|
||||||
|
|
||||||
|
fn aggregate(_1: u8) -> u8 {
|
||||||
|
debug x => _1;
|
||||||
|
let mut _0: u8;
|
||||||
|
let _2: u8;
|
||||||
|
let _3: u8;
|
||||||
|
let mut _4: (u8, u8);
|
||||||
|
let mut _5: bool;
|
||||||
|
let mut _6: u8;
|
||||||
|
scope 1 {
|
||||||
|
debug a => _2;
|
||||||
|
debug b => _3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb0: {
|
||||||
|
StorageLive(_4);
|
||||||
|
_4 = const _;
|
||||||
|
StorageLive(_2);
|
||||||
|
_2 = (_4.0: u8);
|
||||||
|
StorageLive(_3);
|
||||||
|
_3 = (_4.1: u8);
|
||||||
|
StorageDead(_4);
|
||||||
|
StorageLive(_5);
|
||||||
|
StorageLive(_6);
|
||||||
|
_6 = _2;
|
||||||
|
_5 = Eq(move _6, const 7_u8);
|
||||||
|
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||||
|
+ goto -> bb2;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb1: {
|
||||||
|
StorageDead(_6);
|
||||||
|
_0 = _3;
|
||||||
|
goto -> bb3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb2: {
|
||||||
|
StorageDead(_6);
|
||||||
|
_0 = _2;
|
||||||
|
goto -> bb3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb3: {
|
||||||
|
StorageDead(_5);
|
||||||
|
StorageDead(_3);
|
||||||
|
StorageDead(_2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
- // MIR for `aggregate` before JumpThreading
|
||||||
|
+ // MIR for `aggregate` after JumpThreading
|
||||||
|
|
||||||
|
fn aggregate(_1: u8) -> u8 {
|
||||||
|
debug x => _1;
|
||||||
|
let mut _0: u8;
|
||||||
|
let _2: u8;
|
||||||
|
let _3: u8;
|
||||||
|
let mut _4: (u8, u8);
|
||||||
|
let mut _5: bool;
|
||||||
|
let mut _6: u8;
|
||||||
|
scope 1 {
|
||||||
|
debug a => _2;
|
||||||
|
debug b => _3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb0: {
|
||||||
|
StorageLive(_4);
|
||||||
|
_4 = const _;
|
||||||
|
StorageLive(_2);
|
||||||
|
_2 = (_4.0: u8);
|
||||||
|
StorageLive(_3);
|
||||||
|
_3 = (_4.1: u8);
|
||||||
|
StorageDead(_4);
|
||||||
|
StorageLive(_5);
|
||||||
|
StorageLive(_6);
|
||||||
|
_6 = _2;
|
||||||
|
_5 = Eq(move _6, const 7_u8);
|
||||||
|
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
|
||||||
|
+ goto -> bb2;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb1: {
|
||||||
|
StorageDead(_6);
|
||||||
|
_0 = _3;
|
||||||
|
goto -> bb3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb2: {
|
||||||
|
StorageDead(_6);
|
||||||
|
_0 = _2;
|
||||||
|
goto -> bb3;
|
||||||
|
}
|
||||||
|
|
||||||
|
bb3: {
|
||||||
|
StorageDead(_5);
|
||||||
|
StorageDead(_3);
|
||||||
|
StorageDead(_2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verify that we can thread jumps when we assign from an aggregate constant.
|
||||||
|
fn aggregate(x: u8) -> u8 {
|
||||||
|
// CHECK-LABEL: fn aggregate(
|
||||||
|
// CHECK-NOT: switchInt(
|
||||||
|
|
||||||
|
const FOO: (u8, u8) = (5, 13);
|
||||||
|
|
||||||
|
let (a, b) = FOO;
|
||||||
|
if a == 7 {
|
||||||
|
b
|
||||||
|
} else {
|
||||||
|
a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
// CHECK-LABEL: fn main(
|
||||||
too_complex(Ok(0));
|
too_complex(Ok(0));
|
||||||
identity(Ok(0));
|
identity(Ok(0));
|
||||||
custom_discr(false);
|
custom_discr(false);
|
||||||
|
@ -464,6 +480,7 @@ fn main() {
|
||||||
mutable_ref();
|
mutable_ref();
|
||||||
renumbered_bb(true);
|
renumbered_bb(true);
|
||||||
disappearing_bb(7);
|
disappearing_bb(7);
|
||||||
|
aggregate(7);
|
||||||
}
|
}
|
||||||
|
|
||||||
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
|
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
|
||||||
|
@ -476,3 +493,4 @@ fn main() {
|
||||||
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
|
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
|
||||||
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
|
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
|
||||||
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
|
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
|
||||||
|
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue