1
Fork 0

Add hir::GeneratorKind::Gen

This commit is contained in:
Oli Scherer 2023-10-20 19:21:24 +00:00
parent a61cf673cd
commit 14423080f1
12 changed files with 99 additions and 22 deletions

View file

@ -712,7 +712,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let full_span = expr.span.to(await_kw_span); let full_span = expr.span.to(await_kw_span);
match self.coroutine_kind { match self.coroutine_kind {
Some(hir::CoroutineKind::Async(_)) => {} Some(hir::CoroutineKind::Async(_)) => {}
Some(hir::CoroutineKind::Coroutine) | None => { Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks { self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
await_kw_span, await_kw_span,
item_span: self.current_item, item_span: self.current_item,
@ -936,8 +936,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
} }
Some(movability) Some(movability)
} }
Some(hir::CoroutineKind::Async(_)) => { Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => {
panic!("non-`async` closure body turned `async` during lowering"); panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
} }
None => { None => {
if movability == Movability::Static { if movability == Movability::Static {
@ -1446,6 +1446,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> { fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
match self.coroutine_kind { match self.coroutine_kind {
Some(hir::CoroutineKind::Coroutine) => {} Some(hir::CoroutineKind::Coroutine) => {}
Some(hir::CoroutineKind::Gen(_)) => {}
Some(hir::CoroutineKind::Async(_)) => { Some(hir::CoroutineKind::Async(_)) => {
self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }); self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span });
} }

View file

@ -2505,6 +2505,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
}; };
let kind = match use_span.coroutine_kind() { let kind = match use_span.coroutine_kind() {
Some(coroutine_kind) => match coroutine_kind { Some(coroutine_kind) => match coroutine_kind {
CoroutineKind::Gen(kind) => match kind {
CoroutineSource::Block => "gen block",
CoroutineSource::Closure => "gen closure",
_ => bug!("gen block/closure expected, but gen function found."),
},
CoroutineKind::Async(async_kind) => match async_kind { CoroutineKind::Async(async_kind) => match async_kind {
CoroutineSource::Block => "async block", CoroutineSource::Block => "async block",
CoroutineSource::Closure => "async closure", CoroutineSource::Closure => "async closure",

View file

@ -698,6 +698,20 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
" of async function" " of async function"
} }
}, },
Some(hir::CoroutineKind::Gen(gen)) => match gen {
hir::CoroutineSource::Block => " of gen block",
hir::CoroutineSource::Closure => " of gen closure",
hir::CoroutineSource::Fn => {
let parent_item =
hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
let output = &parent_item
.fn_decl()
.expect("coroutine lowered from gen fn should be in fn")
.output;
span = output.span();
" of gen function"
}
},
Some(hir::CoroutineKind::Coroutine) => " of coroutine", Some(hir::CoroutineKind::Coroutine) => " of coroutine",
None => " of closure", None => " of closure",
}; };

View file

@ -560,6 +560,9 @@ pub fn push_item_name(tcx: TyCtxt<'_>, def_id: DefId, qualified: bool, output: &
fn coroutine_kind_label(coroutine_kind: Option<CoroutineKind>) -> &'static str { fn coroutine_kind_label(coroutine_kind: Option<CoroutineKind>) -> &'static str {
match coroutine_kind { match coroutine_kind {
Some(CoroutineKind::Gen(CoroutineSource::Block)) => "gen_block",
Some(CoroutineKind::Gen(CoroutineSource::Closure)) => "gen_closure",
Some(CoroutineKind::Gen(CoroutineSource::Fn)) => "gen_fn",
Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block", Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block",
Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure", Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure",
Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn", Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn",

View file

@ -1513,6 +1513,9 @@ pub enum CoroutineKind {
/// An explicit `async` block or the body of an async function. /// An explicit `async` block or the body of an async function.
Async(CoroutineSource), Async(CoroutineSource),
/// An explicit `gen` block or the body of a `gen` function.
Gen(CoroutineSource),
/// A coroutine literal created via a `yield` inside a closure. /// A coroutine literal created via a `yield` inside a closure.
Coroutine, Coroutine,
} }
@ -1529,6 +1532,14 @@ impl fmt::Display for CoroutineKind {
k.fmt(f) k.fmt(f)
} }
CoroutineKind::Coroutine => f.write_str("coroutine"), CoroutineKind::Coroutine => f.write_str("coroutine"),
CoroutineKind::Gen(k) => {
if f.alternate() {
f.write_str("`gen` ")?;
} else {
f.write_str("gen ")?
}
k.fmt(f)
}
} }
} }
} }
@ -2242,6 +2253,7 @@ impl From<CoroutineKind> for YieldSource {
// Guess based on the kind of the current coroutine. // Guess based on the kind of the current coroutine.
CoroutineKind::Coroutine => Self::Yield, CoroutineKind::Coroutine => Self::Yield,
CoroutineKind::Async(_) => Self::Await { expr: None }, CoroutineKind::Async(_) => Self::Await { expr: None },
CoroutineKind::Gen(_) => Self::Yield,
} }
} }
} }

View file

@ -58,15 +58,16 @@ pub(super) fn check_fn<'a, 'tcx>(
if let Some(kind) = body.coroutine_kind if let Some(kind) = body.coroutine_kind
&& can_be_coroutine.is_some() && can_be_coroutine.is_some()
{ {
let yield_ty = if kind == hir::CoroutineKind::Coroutine { let yield_ty = match kind {
let yield_ty = fcx.next_ty_var(TypeVariableOrigin { hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => {
kind: TypeVariableOriginKind::TypeInference, let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
span, kind: TypeVariableOriginKind::TypeInference,
}); span,
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType); });
yield_ty fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
} else { yield_ty
Ty::new_unit(tcx) }
hir::CoroutineKind::Async(..) => Ty::new_unit(tcx),
}; };
// Resume type defaults to `()` if the coroutine has no argument. // Resume type defaults to `()` if the coroutine has no argument.

View file

@ -148,8 +148,14 @@ impl<O> AssertKind<O> {
RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero", RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero",
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
bug!("`gen fn` should just keep returning `None` after the first time")
}
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
bug!("`gen fn` should just keep returning `None` after panicking")
}
BoundsCheck { .. } | MisalignedPointerDereference { .. } => { BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
bug!("Unexpected AssertKind") bug!("Unexpected AssertKind")
} }
@ -236,10 +242,14 @@ impl<O> AssertKind<O> {
DivisionByZero(_) => middle_assert_divide_by_zero, DivisionByZero(_) => middle_assert_divide_by_zero,
RemainderByZero(_) => middle_assert_remainder_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero,
ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return,
// FIXME(gen_blocks): custom error message for `gen` blocks
ResumedAfterReturn(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_return,
ResumedAfterReturn(CoroutineKind::Coroutine) => { ResumedAfterReturn(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_return middle_assert_coroutine_resume_after_return
} }
ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic, ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic,
// FIXME(gen_blocks): custom error message for `gen` blocks
ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_panic,
ResumedAfterPanic(CoroutineKind::Coroutine) => { ResumedAfterPanic(CoroutineKind::Coroutine) => {
middle_assert_coroutine_resume_after_panic middle_assert_coroutine_resume_after_panic
} }

View file

@ -749,6 +749,7 @@ impl<'tcx> TyCtxt<'tcx> {
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() { DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "async closure", rustc_hir::CoroutineKind::Async(..) => "async closure",
rustc_hir::CoroutineKind::Coroutine => "coroutine", rustc_hir::CoroutineKind::Coroutine => "coroutine",
rustc_hir::CoroutineKind::Gen(..) => "gen closure",
}, },
_ => def_kind.descr(def_id), _ => def_kind.descr(def_id),
} }
@ -766,6 +767,7 @@ impl<'tcx> TyCtxt<'tcx> {
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() { DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "an", rustc_hir::CoroutineKind::Async(..) => "an",
rustc_hir::CoroutineKind::Coroutine => "a", rustc_hir::CoroutineKind::Coroutine => "a",
rustc_hir::CoroutineKind::Gen(..) => "a",
}, },
_ => def_kind.article(), _ => def_kind.article(),
} }

View file

@ -880,18 +880,28 @@ impl<'tcx> Stable<'tcx> for mir::AggregateKind<'tcx> {
} }
} }
impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineSource {
type T = stable_mir::mir::CoroutineSource;
fn stable(&self, _: &mut Tables<'tcx>) -> Self::T {
use rustc_hir::CoroutineSource;
match self {
CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block,
CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure,
CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn,
}
}
}
impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind { impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind {
type T = stable_mir::mir::CoroutineKind; type T = stable_mir::mir::CoroutineKind;
fn stable(&self, _: &mut Tables<'tcx>) -> Self::T { fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
use rustc_hir::{CoroutineKind, CoroutineSource}; use rustc_hir::CoroutineKind;
match self { match self {
CoroutineKind::Async(async_gen) => { CoroutineKind::Async(source) => {
let async_gen = match async_gen { stable_mir::mir::CoroutineKind::Async(source.stable(tables))
CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block, }
CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure, CoroutineKind::Gen(source) => {
CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn, stable_mir::mir::CoroutineKind::Gen(source.stable(tables))
};
stable_mir::mir::CoroutineKind::Async(async_gen)
} }
CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine, CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine,
} }

View file

@ -2425,6 +2425,21 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
CoroutineKind::Async(CoroutineSource::Closure) => { CoroutineKind::Async(CoroutineSource::Closure) => {
format!("future created by async closure is not {trait_name}") format!("future created by async closure is not {trait_name}")
} }
CoroutineKind::Gen(CoroutineSource::Fn) => self
.tcx
.parent(coroutine_did)
.as_local()
.map(|parent_did| hir.local_def_id_to_hir_id(parent_did))
.and_then(|parent_hir_id| hir.opt_name(parent_hir_id))
.map(|name| {
format!("iterator returned by `{name}` is not {trait_name}")
})?,
CoroutineKind::Gen(CoroutineSource::Block) => {
format!("iterator created by gen block is not {trait_name}")
}
CoroutineKind::Gen(CoroutineSource::Closure) => {
format!("iterator created by gen closure is not {trait_name}")
}
}) })
}) })
.unwrap_or_else(|| format!("{future_or_coroutine} is not {trait_name}")); .unwrap_or_else(|| format!("{future_or_coroutine} is not {trait_name}"));
@ -2905,7 +2920,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
} }
ObligationCauseCode::SizedCoroutineInterior(coroutine_def_id) => { ObligationCauseCode::SizedCoroutineInterior(coroutine_def_id) => {
let what = match self.tcx.coroutine_kind(coroutine_def_id) { let what = match self.tcx.coroutine_kind(coroutine_def_id) {
None | Some(hir::CoroutineKind::Coroutine) => "yield", None | Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) => "yield",
Some(hir::CoroutineKind::Async(..)) => "await", Some(hir::CoroutineKind::Async(..)) => "await",
}; };
err.note(format!( err.note(format!(

View file

@ -1614,6 +1614,9 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block", hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function", hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function",
hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure", hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure",
hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block",
hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function",
hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure",
}) })
} }

View file

@ -137,6 +137,7 @@ pub enum UnOp {
pub enum CoroutineKind { pub enum CoroutineKind {
Async(CoroutineSource), Async(CoroutineSource),
Coroutine, Coroutine,
Gen(CoroutineSource),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]