1
Fork 0

Split coroutine desugaring kind from source

This commit is contained in:
Michael Goulet 2023-12-21 18:49:20 +00:00
parent d6d7a93866
commit 004450506e
30 changed files with 448 additions and 239 deletions

View file

@ -59,7 +59,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
use rustc_hir::CoroutineKind;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
@ -254,10 +254,12 @@ impl<'tcx> TransformVisitor<'tcx> {
let source_info = SourceInfo::outermost(body.span);
let none_value = match self.coroutine_kind {
CoroutineKind::Async(_) => span_bug!(body.span, "`Future`s are not fused inherently"),
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
span_bug!(body.span, "`Future`s are not fused inherently")
}
CoroutineKind::Coroutine => span_bug!(body.span, "`Coroutine`s cannot be fused"),
// `gen` continues return `None`
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
@ -271,7 +273,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
// `async gen` continues to return `Poll::Ready(None)`
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
let yield_ty = args.type_at(0);
@ -316,7 +318,7 @@ impl<'tcx> TransformVisitor<'tcx> {
statements: &mut Vec<Statement<'tcx>>,
) {
let rvalue = match self.coroutine_kind {
CoroutineKind::Async(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
if is_return {
@ -345,7 +347,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
}
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
if is_return {
@ -374,7 +376,7 @@ impl<'tcx> TransformVisitor<'tcx> {
)
}
}
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
if is_return {
let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
@ -1426,10 +1428,11 @@ fn create_coroutine_resume_function<'tcx>(
if can_return {
let block = match coroutine_kind {
CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine => {
insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
}
CoroutineKind::AsyncGen(_) | CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
| CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
transform.insert_none_ret_block(body)
}
};
@ -1443,7 +1446,7 @@ fn create_coroutine_resume_function<'tcx>(
match coroutine_kind {
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Gen(_) => {}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
_ => {
make_coroutine_state_argument_pinned(tcx, body);
}
@ -1609,25 +1612,34 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_)));
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
let is_async_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
);
let is_async_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
);
let is_gen_kind = matches!(
body.coroutine_kind(),
Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _))
);
let new_ret_ty = match body.coroutine_kind().unwrap() {
CoroutineKind::Async(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
Ty::new_adt(tcx, poll_adt_ref, poll_args)
}
CoroutineKind::Gen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
// Compute Option<yield_ty>
let option_did = tcx.require_lang_item(LangItem::Option, None);
let option_adt_ref = tcx.adt_def(option_did);
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
Ty::new_adt(tcx, option_adt_ref, option_args)
}
CoroutineKind::AsyncGen(_) => {
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
// The yield ty is already `Poll<Option<yield_ty>>`
old_yield_ty
}