1
Fork 0

Basic generators work

This commit is contained in:
Oli Scherer 2023-10-20 23:19:40 +00:00
parent 998a816106
commit c892b28c02
4 changed files with 98 additions and 34 deletions

View file

@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
struct TransformVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
is_async_kind: bool,
coroutine_kind: hir::CoroutineKind,
state_adt_ref: AdtDef<'tcx>,
state_args: GenericArgsRef<'tcx>,
@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> {
is_return: bool,
statements: &mut Vec<Statement<'tcx>>,
) {
let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
(true, false) => 1, // CoroutineState::Complete
(false, false) => 0, // CoroutineState::Yielded
(true, true) => 0, // Poll::Ready
(false, true) => 1, // Poll::Pending
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
});
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
// `Poll::Pending`
if self.is_async_kind && idx == VariantIdx::new(1) {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
match self.coroutine_kind {
// `Poll::Pending`
CoroutineKind::Async(_) => {
if !is_return {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
// FIXME(swatinem): assert that `val` is indeed unit?
statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
))),
source_info,
});
return;
// FIXME(swatinem): assert that `val` is indeed unit?
statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
))),
source_info,
});
return;
}
}
// `Option::None`
CoroutineKind::Gen(_) => {
if is_return {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
))),
source_info,
});
return;
}
}
CoroutineKind::Coroutine => {}
}
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
statements.push(Statement {
@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
};
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
let (state_adt_ref, state_args) = if is_async_kind {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
(poll_adt_ref, poll_args)
} else {
// Compute CoroutineState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
(state_adt_ref, state_args)
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
CoroutineKind::Async(_) => {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
(poll_adt_ref, poll_args)
}
CoroutineKind::Gen(_) => {
// Compute Option<yield_ty>
let option_did = tcx.require_lang_item(LangItem::Option, None);
let option_adt_ref = tcx.adt_def(option_did);
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
(option_adt_ref, option_args)
}
CoroutineKind::Coroutine => {
// Compute CoroutineState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
(state_adt_ref, state_args)
}
};
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
let mut transform = TransformVisitor {
tcx,
is_async_kind,
coroutine_kind: body.coroutine_kind().unwrap(),
state_adt_ref,
state_args,
remap,