1
Fork 0

Construct body for by-move coroutine closure output

This commit is contained in:
Michael Goulet 2024-01-24 23:38:33 +00:00
parent fc4fff4038
commit 427896dd7e
24 changed files with 233 additions and 15 deletions

View file

@ -50,6 +50,9 @@
//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
//! Otherwise it drops all the values in scope at the last suspension point.
mod by_move_body;
pub use by_move_body::ByMoveBody;
use crate::abort_unwinding_calls;
use crate::deref_separator::deref_finder;
use crate::errors;

View file

@ -0,0 +1,108 @@
use rustc_data_structures::fx::FxIndexSet;
use rustc_hir as hir;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, MirPass};
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
use rustc_target::abi::FieldIdx;
pub struct ByMoveBody;
impl<'tcx> MirPass<'tcx> for ByMoveBody {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
let Some(coroutine_def_id) = body.source.def_id().as_local() else {
return;
};
let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
tcx.coroutine_kind(coroutine_def_id)
else {
return;
};
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
return;
}
let mut by_ref_fields = FxIndexSet::default();
let by_move_upvars = Ty::new_tup_from_iter(
tcx,
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
if capture.is_by_ref() {
by_ref_fields.insert(FieldIdx::from_usize(idx));
}
capture.place.ty()
}),
);
let by_move_coroutine_ty = Ty::new_coroutine(
tcx,
coroutine_def_id.to_def_id(),
ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args: args.as_coroutine().parent_args(),
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
resume_ty: args.as_coroutine().resume_ty(),
yield_ty: args.as_coroutine().yield_ty(),
return_ty: args.as_coroutine().return_ty(),
witness: args.as_coroutine().witness(),
tupled_upvars_ty: by_move_upvars,
},
)
.args,
);
let mut by_move_body = body.clone();
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
by_move_body.source = mir::MirSource {
instance: InstanceDef::CoroutineByMoveShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
},
promoted: None,
};
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
}
}
struct MakeByMoveBody<'tcx> {
tcx: TyCtxt<'tcx>,
by_ref_fields: FxIndexSet<FieldIdx>,
by_move_coroutine_ty: Ty<'tcx>,
}
impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn visit_place(
&mut self,
place: &mut mir::Place<'tcx>,
context: mir::visit::PlaceContext,
location: mir::Location,
) {
if place.local == ty::CAPTURE_STRUCT_LOCAL
&& !place.projection.is_empty()
&& let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
&& self.by_ref_fields.contains(&idx)
{
let (begin, end) = place.projection[1..].split_first().unwrap();
assert_eq!(*begin, mir::ProjectionElem::Deref);
*place = mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems_from_iter(
[mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
.into_iter()
.chain(end.iter().copied()),
),
};
}
self.super_place(place, context, location);
}
fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
if local == ty::CAPTURE_STRUCT_LOCAL {
local_decl.ty = self.by_move_coroutine_ty;
}
}
}

View file

@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> {
| InstanceDef::FnPtrShim(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::CoroutineByMoveShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::CloneShim(..)
| InstanceDef::ThreadLocalShim(..)

View file

@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
| InstanceDef::FnPtrShim(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::CoroutineByMoveShim { .. }
| InstanceDef::ThreadLocalShim { .. }
| InstanceDef::CloneShim(..) => {}

View file

@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> {
&Lint(check_packed_ref::CheckPackedRef),
&Lint(check_const_item_mutation::CheckConstItemMutation),
&Lint(function_item_references::FunctionItemReferences),
// If this is an async closure's output coroutine, generate
// by-move and by-mut bodies if needed. We do this first so
// they can be optimized in lockstep with their parent bodies.
&coroutine::ByMoveBody,
// What we need to do constant evaluation.
&simplify::SimplifyCfg::Initial,
&rustc_peek::SanityCheck, // Just a lint

View file

@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>(
body.pass_count = 1;
}
if let Some(coroutine) = body.coroutine.as_mut()
&& let Some(by_move_body) = coroutine.by_move_body.as_mut()
{
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
}
}
pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) {

View file

@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
}
},
ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => {
return tcx
.optimized_mir(coroutine_def_id)
.coroutine
.as_ref()
.unwrap()
.by_move_body
.as_ref()
.unwrap()
.clone();
}
ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?