Split coroutine desugaring kind from source
This commit is contained in:
parent
d6d7a93866
commit
004450506e
30 changed files with 448 additions and 239 deletions
|
@ -59,7 +59,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
|
|||
use rustc_errors::pluralize;
|
||||
use rustc_hir as hir;
|
||||
use rustc_hir::lang_items::LangItem;
|
||||
use rustc_hir::CoroutineKind;
|
||||
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
|
||||
use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
|
||||
use rustc_index::{Idx, IndexVec};
|
||||
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
|
||||
|
@ -254,10 +254,12 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
let source_info = SourceInfo::outermost(body.span);
|
||||
|
||||
let none_value = match self.coroutine_kind {
|
||||
CoroutineKind::Async(_) => span_bug!(body.span, "`Future`s are not fused inherently"),
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
|
||||
span_bug!(body.span, "`Future`s are not fused inherently")
|
||||
}
|
||||
CoroutineKind::Coroutine => span_bug!(body.span, "`Coroutine`s cannot be fused"),
|
||||
// `gen` continues return `None`
|
||||
CoroutineKind::Gen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
|
||||
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
|
||||
Rvalue::Aggregate(
|
||||
Box::new(AggregateKind::Adt(
|
||||
|
@ -271,7 +273,7 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
)
|
||||
}
|
||||
// `async gen` continues to return `Poll::Ready(None)`
|
||||
CoroutineKind::AsyncGen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
|
||||
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
|
||||
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
|
||||
let yield_ty = args.type_at(0);
|
||||
|
@ -316,7 +318,7 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
statements: &mut Vec<Statement<'tcx>>,
|
||||
) {
|
||||
let rvalue = match self.coroutine_kind {
|
||||
CoroutineKind::Async(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
|
||||
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
|
||||
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
|
||||
if is_return {
|
||||
|
@ -345,7 +347,7 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
)
|
||||
}
|
||||
}
|
||||
CoroutineKind::Gen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
|
||||
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
|
||||
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
|
||||
if is_return {
|
||||
|
@ -374,7 +376,7 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
)
|
||||
}
|
||||
}
|
||||
CoroutineKind::AsyncGen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
|
||||
if is_return {
|
||||
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
|
||||
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
|
||||
|
@ -1426,10 +1428,11 @@ fn create_coroutine_resume_function<'tcx>(
|
|||
|
||||
if can_return {
|
||||
let block = match coroutine_kind {
|
||||
CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine => {
|
||||
insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
|
||||
}
|
||||
CoroutineKind::AsyncGen(_) | CoroutineKind::Gen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
|
||||
| CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
|
||||
transform.insert_none_ret_block(body)
|
||||
}
|
||||
};
|
||||
|
@ -1443,7 +1446,7 @@ fn create_coroutine_resume_function<'tcx>(
|
|||
match coroutine_kind {
|
||||
// Iterator::next doesn't accept a pinned argument,
|
||||
// unlike for all other coroutine kinds.
|
||||
CoroutineKind::Gen(_) => {}
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
|
||||
_ => {
|
||||
make_coroutine_state_argument_pinned(tcx, body);
|
||||
}
|
||||
|
@ -1609,25 +1612,34 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
|||
}
|
||||
};
|
||||
|
||||
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
||||
let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_)));
|
||||
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
|
||||
let is_async_kind = matches!(
|
||||
body.coroutine_kind(),
|
||||
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
|
||||
);
|
||||
let is_async_gen_kind = matches!(
|
||||
body.coroutine_kind(),
|
||||
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
|
||||
);
|
||||
let is_gen_kind = matches!(
|
||||
body.coroutine_kind(),
|
||||
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
|
||||
);
|
||||
let new_ret_ty = match body.coroutine_kind().unwrap() {
|
||||
CoroutineKind::Async(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
|
||||
// Compute Poll<return_ty>
|
||||
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
||||
let poll_adt_ref = tcx.adt_def(poll_did);
|
||||
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
|
||||
Ty::new_adt(tcx, poll_adt_ref, poll_args)
|
||||
}
|
||||
CoroutineKind::Gen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
|
||||
// Compute Option<yield_ty>
|
||||
let option_did = tcx.require_lang_item(LangItem::Option, None);
|
||||
let option_adt_ref = tcx.adt_def(option_did);
|
||||
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
|
||||
Ty::new_adt(tcx, option_adt_ref, option_args)
|
||||
}
|
||||
CoroutineKind::AsyncGen(_) => {
|
||||
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
|
||||
// The yield ty is already `Poll<Option<yield_ty>>`
|
||||
old_yield_ty
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue