Implement async gen blocks

This commit is contained in:
Michael Goulet 2023-11-28 18:18:19 +00:00
parent a0cbc168c9
commit 96bb542a31
32 changed files with 563 additions and 54 deletions

View file

@ -119,9 +119,9 @@ fn fn_sig_for_fn_abi<'tcx>(
// unlike for all other coroutine kinds.
env_ty
}
hir::CoroutineKind::Async(_) | hir::CoroutineKind::Coroutine => {
Ty::new_adt(tcx, pin_adt_ref, pin_args)
}
hir::CoroutineKind::Async(_)
| hir::CoroutineKind::AsyncGen(_)
| hir::CoroutineKind::Coroutine => Ty::new_adt(tcx, pin_adt_ref, pin_args),
};
// The `FnSig` and the `ret_ty` here is for a coroutines main
@ -168,6 +168,30 @@ fn fn_sig_for_fn_abi<'tcx>(
(None, ret_ty)
}
hir::CoroutineKind::AsyncGen(_) => {
// The signature should be
// `AsyncIterator::poll_next(_, &mut Context<'_>) -> Poll<Option<Output>>`
assert_eq!(sig.return_ty, tcx.types.unit);
// Yield type is already `Poll<Option<yield_ty>>`
let ret_ty = sig.yield_ty;
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` which is used in codegen.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
let expected_adt =
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
};
}
let context_mut_ref = Ty::new_task_context(tcx);
(Some(context_mut_ref), ret_ty)
}
hir::CoroutineKind::Coroutine => {
// The signature should be `Coroutine::resume(_, Resume) -> CoroutineState<Yield, Return>`
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);

View file

@ -271,6 +271,21 @@ fn resolve_associated_item<'tcx>(
debug_assert!(tcx.defaultness(trait_item_id).has_value());
Some(Instance::new(trait_item_id, rcvr_args))
}
} else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() {
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
bug!()
};
if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next {
span_bug!(
tcx.def_span(coroutine_def_id),
"no definition for `{trait_ref}::{}` for built-in coroutine type",
tcx.item_name(trait_item_id)
)
}
// `AsyncIterator::poll_next` is generated by the compiler.
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
} else if Some(trait_ref.def_id) == lang_items.coroutine_trait() {
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
bug!()