Rework the ByMoveBody shim to actually work correctly
This commit is contained in:
parent
1921968cc5
commit
3674032eb2
5 changed files with 336 additions and 36 deletions
|
@ -60,14 +60,13 @@
|
||||||
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
|
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
|
||||||
//! we use this "by move" body instead.
|
//! we use this "by move" body instead.
|
||||||
|
|
||||||
use itertools::Itertools;
|
use rustc_data_structures::unord::UnordMap;
|
||||||
|
|
||||||
use rustc_data_structures::unord::UnordSet;
|
|
||||||
use rustc_hir as hir;
|
use rustc_hir as hir;
|
||||||
|
use rustc_middle::hir::place::{Projection, ProjectionKind};
|
||||||
use rustc_middle::mir::visit::MutVisitor;
|
use rustc_middle::mir::visit::MutVisitor;
|
||||||
use rustc_middle::mir::{self, dump_mir, MirPass};
|
use rustc_middle::mir::{self, dump_mir, MirPass};
|
||||||
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
|
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
|
||||||
use rustc_target::abi::FieldIdx;
|
use rustc_target::abi::{FieldIdx, VariantIdx};
|
||||||
|
|
||||||
pub struct ByMoveBody;
|
pub struct ByMoveBody;
|
||||||
|
|
||||||
|
@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||||
.tuple_fields()
|
.tuple_fields()
|
||||||
.len();
|
.len();
|
||||||
|
|
||||||
let mut by_ref_fields = UnordSet::default();
|
let mut field_remapping = UnordMap::default();
|
||||||
for (idx, (coroutine_capture, parent_capture)) in tcx
|
|
||||||
|
let mut parent_captures =
|
||||||
|
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
|
||||||
|
|
||||||
|
for (child_field_idx, child_capture) in tcx
|
||||||
.closure_captures(coroutine_def_id)
|
.closure_captures(coroutine_def_id)
|
||||||
.iter()
|
.iter()
|
||||||
|
.copied()
|
||||||
// By construction we capture all the args first.
|
// By construction we capture all the args first.
|
||||||
.skip(num_args)
|
.skip(num_args)
|
||||||
.zip_eq(tcx.closure_captures(parent_def_id))
|
|
||||||
.enumerate()
|
.enumerate()
|
||||||
{
|
{
|
||||||
// This upvar is captured by-move from the parent closure, but by-ref
|
loop {
|
||||||
// from the inner async block. That means that it's being borrowed from
|
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
|
||||||
// the outer closure body -- we need to change the coroutine to take the
|
bug!("we ran out of parent captures!")
|
||||||
// upvar by value.
|
};
|
||||||
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
|
|
||||||
assert_ne!(
|
|
||||||
coroutine_kind,
|
|
||||||
ty::ClosureKind::FnOnce,
|
|
||||||
"`FnOnce` coroutine-closures return coroutines that capture from \
|
|
||||||
their body; it will always result in a borrowck error!"
|
|
||||||
);
|
|
||||||
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we're actually talking about the same capture.
|
if !std::iter::zip(
|
||||||
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
|
&child_capture.place.projections,
|
||||||
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
|
&parent_capture.place.projections,
|
||||||
|
)
|
||||||
|
.all(|(child, parent)| child.kind == parent.kind)
|
||||||
|
{
|
||||||
|
// Skip this field.
|
||||||
|
let _ = parent_captures.next().unwrap();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let child_precise_captures =
|
||||||
|
&child_capture.place.projections[parent_capture.place.projections.len()..];
|
||||||
|
|
||||||
|
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
|
||||||
|
if needs_deref {
|
||||||
|
assert_ne!(
|
||||||
|
coroutine_kind,
|
||||||
|
ty::ClosureKind::FnOnce,
|
||||||
|
"`FnOnce` coroutine-closures return coroutines that capture from \
|
||||||
|
their body; it will always result in a borrowck error!"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut parent_capture_ty = parent_capture.place.ty();
|
||||||
|
parent_capture_ty = match parent_capture.info.capture_kind {
|
||||||
|
ty::UpvarCapture::ByValue => parent_capture_ty,
|
||||||
|
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
|
||||||
|
tcx,
|
||||||
|
tcx.lifetimes.re_erased,
|
||||||
|
parent_capture_ty,
|
||||||
|
kind.to_mutbl_lossy(),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
field_remapping.insert(
|
||||||
|
FieldIdx::from_usize(child_field_idx + num_args),
|
||||||
|
(
|
||||||
|
FieldIdx::from_usize(parent_field_idx + num_args),
|
||||||
|
parent_capture_ty,
|
||||||
|
needs_deref,
|
||||||
|
child_precise_captures,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if coroutine_kind == ty::ClosureKind::FnOnce {
|
||||||
|
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let by_move_coroutine_ty = tcx
|
let by_move_coroutine_ty = tcx
|
||||||
|
@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut by_move_body = body.clone();
|
let mut by_move_body = body.clone();
|
||||||
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
|
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
|
||||||
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
|
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
|
||||||
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
|
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
|
||||||
coroutine_def_id: coroutine_def_id.to_def_id(),
|
coroutine_def_id: coroutine_def_id.to_def_id(),
|
||||||
|
@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
|
||||||
|
|
||||||
struct MakeByMoveBody<'tcx> {
|
struct MakeByMoveBody<'tcx> {
|
||||||
tcx: TyCtxt<'tcx>,
|
tcx: TyCtxt<'tcx>,
|
||||||
by_ref_fields: UnordSet<FieldIdx>,
|
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
|
||||||
by_move_coroutine_ty: Ty<'tcx>,
|
by_move_coroutine_ty: Ty<'tcx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
|
||||||
location: mir::Location,
|
location: mir::Location,
|
||||||
) {
|
) {
|
||||||
if place.local == ty::CAPTURE_STRUCT_LOCAL
|
if place.local == ty::CAPTURE_STRUCT_LOCAL
|
||||||
&& let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
|
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
|
||||||
place.projection.split_first()
|
place.projection.split_first()
|
||||||
&& self.by_ref_fields.contains(&idx)
|
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
|
||||||
|
self.field_remapping.get(&idx)
|
||||||
{
|
{
|
||||||
let (begin, end) = projection.split_first().unwrap();
|
let final_deref = if needs_deref {
|
||||||
// FIXME(async_closures): I'm actually a bit surprised to see that we always
|
let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
|
||||||
// initially deref the by-ref upvars. If this is not actually true, then we
|
bug!();
|
||||||
// will at least get an ICE that explains why this isn't true :^)
|
};
|
||||||
assert_eq!(*begin, mir::ProjectionElem::Deref);
|
rest
|
||||||
// Peel one ref off of the ty.
|
} else {
|
||||||
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
|
projection
|
||||||
|
};
|
||||||
|
|
||||||
|
let additional_projections =
|
||||||
|
additional_projections.iter().map(|elem| match elem.kind {
|
||||||
|
ProjectionKind::Deref => mir::ProjectionElem::Deref,
|
||||||
|
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
|
||||||
|
mir::ProjectionElem::Field(idx, elem.ty)
|
||||||
|
}
|
||||||
|
_ => unreachable!("precise captures only through fields and derefs"),
|
||||||
|
});
|
||||||
|
|
||||||
*place = mir::Place {
|
*place = mir::Place {
|
||||||
local: place.local,
|
local: place.local,
|
||||||
projection: self.tcx.mk_place_elems_from_iter(
|
projection: self.tcx.mk_place_elems_from_iter(
|
||||||
[mir::ProjectionElem::Field(idx, peeled_ty)]
|
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.chain(end.iter().copied()),
|
.chain(additional_projections)
|
||||||
|
.chain(final_deref.iter().copied()),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
untouched
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
|
@ -0,0 +1,29 @@
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
fixed
|
||||||
|
after await
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
untouched
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
|
@ -0,0 +1,29 @@
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
fixed
|
||||||
|
after await
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
after await
|
||||||
|
fixed
|
||||||
|
untouched
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
||||||
|
|
||||||
|
after call
|
||||||
|
drop first
|
||||||
|
after await
|
||||||
|
uncaptured
|
157
tests/ui/async-await/async-closures/precise-captures.rs
Normal file
157
tests/ui/async-await/async-closures/precise-captures.rs
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
//@ aux-build:block-on.rs
|
||||||
|
//@ edition:2021
|
||||||
|
//@ run-pass
|
||||||
|
//@ check-run-results
|
||||||
|
//@ revisions: call call_once force_once
|
||||||
|
|
||||||
|
// call - Call the closure regularly.
|
||||||
|
// call_once - Call the closure w/ `async FnOnce`, so exercising the by_move shim.
|
||||||
|
// force_once - Force the closure mode to `FnOnce`, so exercising what was fixed
|
||||||
|
// in <https://github.com/rust-lang/rust/pull/123350>.
|
||||||
|
|
||||||
|
#![feature(async_closure)]
|
||||||
|
#![allow(unused_mut)]
|
||||||
|
|
||||||
|
extern crate block_on;
|
||||||
|
|
||||||
|
#[cfg(any(call, force_once))]
|
||||||
|
macro_rules! call {
|
||||||
|
($c:expr) => { ($c)() }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(call_once)]
|
||||||
|
async fn call_once(f: impl async FnOnce()) {
|
||||||
|
f().await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(call_once)]
|
||||||
|
macro_rules! call {
|
||||||
|
($c:expr) => { call_once($c) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(force_once))]
|
||||||
|
macro_rules! guidance {
|
||||||
|
($c:expr) => { $c }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(force_once)]
|
||||||
|
fn infer_fnonce(c: impl async FnOnce()) -> impl async FnOnce() { c }
|
||||||
|
|
||||||
|
#[cfg(force_once)]
|
||||||
|
macro_rules! guidance {
|
||||||
|
($c:expr) => { infer_fnonce($c) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Drop(&'static str);
|
||||||
|
|
||||||
|
impl std::ops::Drop for Drop {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
println!("{}", self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct S {
|
||||||
|
a: i32,
|
||||||
|
b: Drop,
|
||||||
|
c: Drop,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn async_main() {
|
||||||
|
// Precise capture struct
|
||||||
|
{
|
||||||
|
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||||
|
let mut c = guidance!(async || {
|
||||||
|
s.a = 2;
|
||||||
|
let w = &mut s.b;
|
||||||
|
w.0 = "fixed";
|
||||||
|
});
|
||||||
|
s.c.0 = "uncaptured";
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Precise capture &mut struct
|
||||||
|
{
|
||||||
|
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||||
|
let mut c = guidance!(async || {
|
||||||
|
s.a = 2;
|
||||||
|
let w = &mut s.b;
|
||||||
|
w.0 = "fixed";
|
||||||
|
});
|
||||||
|
s.c.0 = "uncaptured";
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Precise capture struct by move
|
||||||
|
{
|
||||||
|
let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||||
|
let mut c = guidance!(async move || {
|
||||||
|
s.a = 2;
|
||||||
|
let w = &mut s.b;
|
||||||
|
w.0 = "fixed";
|
||||||
|
});
|
||||||
|
s.c.0 = "uncaptured";
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Precise capture &mut struct by move
|
||||||
|
{
|
||||||
|
let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
|
||||||
|
let mut c = guidance!(async move || {
|
||||||
|
s.a = 2;
|
||||||
|
let w = &mut s.b;
|
||||||
|
w.0 = "fixed";
|
||||||
|
});
|
||||||
|
// `s` is still captured fully as `&mut S`.
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Precise capture struct, consume field
|
||||||
|
{
|
||||||
|
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
|
||||||
|
let c = guidance!(async move || {
|
||||||
|
// s.a = 2; // FIXME(async_closures): Figure out why this fails
|
||||||
|
drop(s.b);
|
||||||
|
});
|
||||||
|
s.c.0 = "uncaptured";
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
// Precise capture struct by move, consume field
|
||||||
|
{
|
||||||
|
let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
|
||||||
|
let c = guidance!(async move || {
|
||||||
|
// s.a = 2; // FIXME(async_closures): Figure out why this fails
|
||||||
|
drop(s.b);
|
||||||
|
});
|
||||||
|
s.c.0 = "uncaptured";
|
||||||
|
let fut = call!(c);
|
||||||
|
println!("after call");
|
||||||
|
fut.await;
|
||||||
|
println!("after await");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
block_on::block_on(async_main());
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue