Basic generators work
This commit is contained in:
parent
998a816106
commit
c892b28c02
4 changed files with 98 additions and 34 deletions
|
@ -149,13 +149,14 @@ impl<O> AssertKind<O> {
|
||||||
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
|
ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
|
||||||
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
|
ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
|
||||||
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
|
ResumedAfterReturn(CoroutineKind::Gen(_)) => {
|
||||||
bug!("`gen fn` should just keep returning `None` after the first time")
|
"`gen fn` should just keep returning `None` after completion"
|
||||||
}
|
}
|
||||||
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
|
ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
|
||||||
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
|
ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
|
||||||
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
|
ResumedAfterPanic(CoroutineKind::Gen(_)) => {
|
||||||
bug!("`gen fn` should just keep returning `None` after panicking")
|
"`gen fn` should just keep returning `None` after panicking"
|
||||||
}
|
}
|
||||||
|
|
||||||
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
|
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
|
||||||
bug!("Unexpected AssertKind")
|
bug!("Unexpected AssertKind")
|
||||||
}
|
}
|
||||||
|
|
|
@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
|
||||||
|
|
||||||
struct TransformVisitor<'tcx> {
|
struct TransformVisitor<'tcx> {
|
||||||
tcx: TyCtxt<'tcx>,
|
tcx: TyCtxt<'tcx>,
|
||||||
is_async_kind: bool,
|
coroutine_kind: hir::CoroutineKind,
|
||||||
state_adt_ref: AdtDef<'tcx>,
|
state_adt_ref: AdtDef<'tcx>,
|
||||||
state_args: GenericArgsRef<'tcx>,
|
state_args: GenericArgsRef<'tcx>,
|
||||||
|
|
||||||
|
@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> {
|
||||||
is_return: bool,
|
is_return: bool,
|
||||||
statements: &mut Vec<Statement<'tcx>>,
|
statements: &mut Vec<Statement<'tcx>>,
|
||||||
) {
|
) {
|
||||||
let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
|
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
|
||||||
(true, false) => 1, // CoroutineState::Complete
|
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
|
||||||
(false, false) => 0, // CoroutineState::Yielded
|
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
|
||||||
(true, true) => 0, // Poll::Ready
|
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
|
||||||
(false, true) => 1, // Poll::Pending
|
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
|
||||||
|
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
|
||||||
|
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
|
||||||
});
|
});
|
||||||
|
|
||||||
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
|
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
|
||||||
|
|
||||||
// `Poll::Pending`
|
match self.coroutine_kind {
|
||||||
if self.is_async_kind && idx == VariantIdx::new(1) {
|
// `Poll::Pending`
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
CoroutineKind::Async(_) => {
|
||||||
|
if !is_return {
|
||||||
|
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||||
|
|
||||||
// FIXME(swatinem): assert that `val` is indeed unit?
|
// FIXME(swatinem): assert that `val` is indeed unit?
|
||||||
statements.push(Statement {
|
statements.push(Statement {
|
||||||
kind: StatementKind::Assign(Box::new((
|
kind: StatementKind::Assign(Box::new((
|
||||||
Place::return_place(),
|
Place::return_place(),
|
||||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||||
))),
|
))),
|
||||||
source_info,
|
source_info,
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// `Option::None`
|
||||||
|
CoroutineKind::Gen(_) => {
|
||||||
|
if is_return {
|
||||||
|
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||||
|
|
||||||
|
statements.push(Statement {
|
||||||
|
kind: StatementKind::Assign(Box::new((
|
||||||
|
Place::return_place(),
|
||||||
|
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||||
|
))),
|
||||||
|
source_info,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CoroutineKind::Coroutine => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
|
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
|
||||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
|
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
|
||||||
|
|
||||||
statements.push(Statement {
|
statements.push(Statement {
|
||||||
|
@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
};
|
};
|
||||||
|
|
||||||
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
||||||
let (state_adt_ref, state_args) = if is_async_kind {
|
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
|
||||||
// Compute Poll<return_ty>
|
CoroutineKind::Async(_) => {
|
||||||
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
// Compute Poll<return_ty>
|
||||||
let poll_adt_ref = tcx.adt_def(poll_did);
|
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
||||||
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
|
let poll_adt_ref = tcx.adt_def(poll_did);
|
||||||
(poll_adt_ref, poll_args)
|
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
|
||||||
} else {
|
(poll_adt_ref, poll_args)
|
||||||
// Compute CoroutineState<yield_ty, return_ty>
|
}
|
||||||
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
CoroutineKind::Gen(_) => {
|
||||||
let state_adt_ref = tcx.adt_def(state_did);
|
// Compute Option<yield_ty>
|
||||||
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
|
let option_did = tcx.require_lang_item(LangItem::Option, None);
|
||||||
(state_adt_ref, state_args)
|
let option_adt_ref = tcx.adt_def(option_did);
|
||||||
|
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
|
||||||
|
(option_adt_ref, option_args)
|
||||||
|
}
|
||||||
|
CoroutineKind::Coroutine => {
|
||||||
|
// Compute CoroutineState<yield_ty, return_ty>
|
||||||
|
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
||||||
|
let state_adt_ref = tcx.adt_def(state_did);
|
||||||
|
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
|
||||||
|
(state_adt_ref, state_args)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
|
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
|
||||||
|
|
||||||
|
@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
||||||
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
|
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
|
||||||
let mut transform = TransformVisitor {
|
let mut transform = TransformVisitor {
|
||||||
tcx,
|
tcx,
|
||||||
is_async_kind,
|
coroutine_kind: body.coroutine_kind().unwrap(),
|
||||||
state_adt_ref,
|
state_adt_ref,
|
||||||
state_args,
|
state_args,
|
||||||
remap,
|
remap,
|
||||||
|
|
|
@ -258,6 +258,19 @@ fn resolve_associated_item<'tcx>(
|
||||||
debug_assert!(tcx.defaultness(trait_item_id).has_value());
|
debug_assert!(tcx.defaultness(trait_item_id).has_value());
|
||||||
Some(Instance::new(trait_item_id, rcvr_args))
|
Some(Instance::new(trait_item_id, rcvr_args))
|
||||||
}
|
}
|
||||||
|
} else if Some(trait_ref.def_id) == lang_items.iterator_trait() {
|
||||||
|
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
|
||||||
|
bug!()
|
||||||
|
};
|
||||||
|
if Some(trait_item_id) == tcx.lang_items().next_fn() {
|
||||||
|
// `Iterator::next` is generated by the compiler.
|
||||||
|
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
|
||||||
|
} else {
|
||||||
|
// All other methods are default methods of the `Iterator` trait.
|
||||||
|
// (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`)
|
||||||
|
debug_assert!(tcx.defaultness(trait_item_id).has_value());
|
||||||
|
Some(Instance::new(trait_item_id, rcvr_args))
|
||||||
|
}
|
||||||
} else if Some(trait_ref.def_id) == lang_items.gen_trait() {
|
} else if Some(trait_ref.def_id) == lang_items.gen_trait() {
|
||||||
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
|
let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else {
|
||||||
bug!()
|
bug!()
|
||||||
|
|
18
tests/ui/coroutine/gen_block_iterate.rs
Normal file
18
tests/ui/coroutine/gen_block_iterate.rs
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
// revisions: next old
|
||||||
|
//compile-flags: --edition 2024 -Zunstable-options
|
||||||
|
//[next] compile-flags: -Ztrait-solver=next
|
||||||
|
// run-pass
|
||||||
|
#![feature(coroutines)]
|
||||||
|
|
||||||
|
fn foo() -> impl Iterator<Item = u32> {
|
||||||
|
gen { yield 42; for x in 3..6 { yield x } }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let mut iter = foo();
|
||||||
|
assert_eq!(iter.next(), Some(42));
|
||||||
|
assert_eq!(iter.next(), Some(3));
|
||||||
|
assert_eq!(iter.next(), Some(4));
|
||||||
|
assert_eq!(iter.next(), Some(5));
|
||||||
|
assert_eq!(iter.next(), None);
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue