Only split by-ref/by-move futures for async closures

This commit is contained in:
Michael Goulet 2024-02-13 15:29:50 +00:00
parent e760daa6a7
commit 05116c5c30
33 changed files with 119 additions and 432 deletions

View file

@ -67,45 +67,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
by_move_body.source = mir::MirSource {
instance: InstanceDef::CoroutineKindShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
target_kind: ty::ClosureKind::FnOnce,
},
promoted: None,
};
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
// If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
// This is actually just a copy of the by-ref body, but with a different self type.
// FIXME(async_closures): We could probably unify this with the by-ref body somehow.
if coroutine_kind == ty::ClosureKind::Fn {
let by_mut_coroutine_ty = Ty::new_coroutine(
tcx,
coroutine_def_id.to_def_id(),
ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args: args.as_coroutine().parent_args(),
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
resume_ty: args.as_coroutine().resume_ty(),
yield_ty: args.as_coroutine().yield_ty(),
return_ty: args.as_coroutine().return_ty(),
witness: args.as_coroutine().witness(),
tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
},
)
.args,
);
let mut by_mut_body = body.clone();
by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
by_mut_body.source = mir::MirSource {
instance: InstanceDef::CoroutineKindShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
target_kind: ty::ClosureKind::FnMut,
},
promoted: None,
};
body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
}
}
}

View file

@ -186,9 +186,6 @@ fn run_passes_inner<'tcx>(
if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
}
if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
}
}
}

View file

@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_index::{Idx, IndexVec};
@ -70,39 +70,13 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
}
ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind,
} => match target_kind {
ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
ty::ClosureKind::FnMut => {
// No need to optimize the body, it has already been optimized
// since we steal it from the `AsyncFn::call` body and just fix
// the return type.
return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
}
ty::ClosureKind::FnOnce => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}
},
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}
ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind {
ty::ClosureKind::Fn => unreachable!(),
ty::ClosureKind::FnMut => {
return tcx
.optimized_mir(coroutine_def_id)
.coroutine_by_mut_body()
.unwrap()
.clone();
}
ty::ClosureKind::FnOnce => {
return tcx
.optimized_mir(coroutine_def_id)
.coroutine_by_move_body()
.unwrap()
.clone();
}
},
ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
}
ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
@ -123,21 +97,11 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() {
coroutine_body.coroutine_drop().unwrap()
} else {
match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() {
ty::ClosureKind::Fn => {
unreachable!()
}
ty::ClosureKind::FnMut => coroutine_body
.coroutine_by_mut_body()
.unwrap()
.coroutine_drop()
.unwrap(),
ty::ClosureKind::FnOnce => coroutine_body
.coroutine_by_move_body()
.unwrap()
.coroutine_drop()
.unwrap(),
}
assert_eq!(
args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
ty::ClosureKind::FnOnce
);
coroutine_body.coroutine_by_move_body().unwrap().coroutine_drop().unwrap()
};
let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
@ -1112,7 +1076,6 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnOnce,
});
let body =
@ -1121,40 +1084,3 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
body
}
fn build_construct_coroutine_by_mut_shim<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_closure_def_id: DefId,
) -> Body<'tcx> {
let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
bug!();
};
let args = args.as_coroutine_closure();
body.local_decls[RETURN_PLACE].ty =
tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(coroutine_closure_def_id),
ty::ClosureKind::FnMut,
tcx.lifetimes.re_erased,
args.tupled_upvars_ty(),
args.coroutine_captures_by_ref_ty(),
)
}));
body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnMut,
});
body.pass_count = 0;
dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));
body
}