In rustc_mir_tranform, iterate over index newtypes instead of ints

This commit is contained in:
Yotam Ofek 2025-04-11 14:26:26 +00:00
parent 69b3959afe
commit c36e8fcc3c
6 changed files with 46 additions and 53 deletions

View file

@ -257,6 +257,13 @@ impl Parse for Newtype {
}
}
impl std::ops::AddAssign<usize> for #name {
#[inline]
fn add_assign(&mut self, other: usize) {
*self = *self + other;
}
}
impl rustc_index::Idx for #name {
#[inline]
fn new(value: usize) -> Self {

View file

@ -547,7 +547,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
for bb in START_BLOCK..body.basic_blocks.next_index() {
for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
@ -556,11 +556,11 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind() {
if def_id == get_context_def_id {
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
@ -1057,7 +1057,7 @@ fn insert_switch<'tcx>(
let blocks = body.basic_blocks_mut().iter_mut();
for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) {
*target = BasicBlock::new(target.index() + 1);
*target += 1;
}
}
@ -1209,14 +1209,8 @@ fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::Typing
}
// If there's a return terminator the function may return.
for block in body.basic_blocks.iter() {
if let TerminatorKind::Return = block.terminator().kind {
return true;
}
}
body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
// Otherwise the function can't return.
false
}
fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
@ -1293,12 +1287,12 @@ fn create_coroutine_resume_function<'tcx>(
kind: TerminatorKind::Goto { target: poison_block },
};
}
} else if !block.is_cleanup {
} else if !block.is_cleanup
// Any terminators that *can* unwind but don't have an unwind target set are also
// pointed at our poisoning block (unless they're part of the cleanup path).
if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() {
*unwind = UnwindAction::Cleanup(poison_block);
}
&& let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
{
*unwind = UnwindAction::Cleanup(poison_block);
}
}
}
@ -1340,12 +1334,14 @@ fn create_coroutine_resume_function<'tcx>(
make_coroutine_state_argument_indirect(tcx, body);
match transform.coroutine_kind {
CoroutineKind::Coroutine(_)
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
{
make_coroutine_state_argument_pinned(tcx, body);
}
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
_ => {
make_coroutine_state_argument_pinned(tcx, body);
}
}
// Make sure we remove dead blocks to remove
@ -1408,8 +1404,7 @@ fn create_cases<'tcx>(
let mut statements = Vec::new();
// Create StorageLive instructions for locals with live storage
for i in 0..(body.local_decls.len()) {
let l = Local::new(i);
for l in body.local_decls.indices() {
let needs_storage_live = point.storage_liveness.contains(l)
&& !transform.remap.contains(l)
&& !transform.always_live_locals.contains(l);
@ -1535,15 +1530,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
let coroutine_kind = body.coroutine_kind().unwrap();
// Get the discriminant type and args which typeck computed
let (discr_ty, movable) = match *coroutine_ty.kind() {
ty::Coroutine(_, args) => {
let args = args.as_coroutine();
(args.discr_ty(tcx), coroutine_kind.movability() == hir::Movability::Movable)
}
_ => {
tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
}
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
};
let discr_ty = args.as_coroutine().discr_ty(tcx);
let new_ret_ty = match coroutine_kind {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
@ -1610,6 +1600,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
let always_live_locals = always_storage_live_locals(body);
let movable = coroutine_kind.movability() == hir::Movability::Movable;
let liveness_info =
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

View file

@ -103,9 +103,8 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
let mut should_cleanup = false;
// Also consider newly generated bbs in the same pass
for i in 0..body.basic_blocks.len() {
for parent in body.basic_blocks.indices() {
let bbs = &*body.basic_blocks;
let parent = BasicBlock::from_usize(i);
let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");

View file

@ -20,13 +20,11 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let typing_env = body.typing_env(tcx);
let mut should_cleanup = false;
for i in 0..body.basic_blocks.len() {
let bbs = &*body.basic_blocks;
let bb_idx = BasicBlock::from_usize(i);
match bbs[bb_idx].terminator().kind {
for bb_idx in body.basic_blocks.indices() {
match &body.basic_blocks[bb_idx].terminator().kind {
TerminatorKind::SwitchInt {
discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)),
ref targets,
discr: Operand::Copy(_) | Operand::Move(_),
targets,
..
// We require that the possible target blocks don't contain this block.
} if !targets.all_targets().contains(&bb_idx) => {}
@ -66,9 +64,10 @@ trait SimplifyMatch<'tcx> {
typing_env: ty::TypingEnv<'tcx>,
) -> Option<()> {
let bbs = &body.basic_blocks;
let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
_ => unreachable!(),
let TerminatorKind::SwitchInt { discr, targets, .. } =
&bbs[switch_bb_idx].terminator().kind
else {
unreachable!();
};
let discr_ty = discr.ty(body.local_decls(), tcx);

View file

@ -18,19 +18,17 @@ impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators {
// find basic blocks with no statement and a return terminator
let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len());
let bbs = body.basic_blocks_mut();
for idx in bbs.indices() {
if bbs[idx].statements.is_empty()
&& bbs[idx].terminator().kind == TerminatorKind::Return
{
for (idx, bb) in bbs.iter_enumerated() {
if bb.statements.is_empty() && bb.terminator().kind == TerminatorKind::Return {
bbs_simple_returns.insert(idx);
}
}
for bb in bbs {
if let TerminatorKind::Goto { target } = bb.terminator().kind {
if bbs_simple_returns.contains(target) {
bb.terminator_mut().kind = TerminatorKind::Return;
}
if let TerminatorKind::Goto { target } = bb.terminator().kind
&& bbs_simple_returns.contains(target)
{
bb.terminator_mut().kind = TerminatorKind::Return;
}
}

View file

@ -221,12 +221,11 @@ impl<'a, 'tcx> CfgChecker<'a, 'tcx> {
// Check for cycles
let mut stack = FxHashSet::default();
for i in 0..parent.len() {
let mut bb = BasicBlock::from_usize(i);
for (mut bb, parent) in parent.iter_enumerated_mut() {
stack.clear();
stack.insert(bb);
loop {
let Some(parent) = parent[bb].take() else { break };
let Some(parent) = parent.take() else { break };
let no_cycle = stack.insert(parent);
if !no_cycle {
self.fail(