Auto merge of #116447 - oli-obk:gen_fn, r=compiler-errors
Implement `gen` blocks in the 2024 edition Coroutines tracking issue https://github.com/rust-lang/rust/issues/43122 `gen` block tracking issue https://github.com/rust-lang/rust/issues/117078 This PR implements `gen` blocks that implement `Iterator`. Most of the logic with `async` blocks is shared, and thus I renamed various types that were referring to `async` specifically. An example usage of `gen` blocks is ```rust fn foo() -> impl Iterator<Item = i32> { gen { yield 42; for i in 5..18 { if i.is_even() { continue } yield i * 2; } } } ``` The limitations (to be resolved) of the implementation are listed in the tracking issue
This commit is contained in:
commit
2cad938a81
75 changed files with 1096 additions and 148 deletions
|
@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
|
|||
|
||||
struct TransformVisitor<'tcx> {
|
||||
tcx: TyCtxt<'tcx>,
|
||||
is_async_kind: bool,
|
||||
coroutine_kind: hir::CoroutineKind,
|
||||
state_adt_ref: AdtDef<'tcx>,
|
||||
state_args: GenericArgsRef<'tcx>,
|
||||
|
||||
|
@ -249,6 +249,47 @@ struct TransformVisitor<'tcx> {
|
|||
}
|
||||
|
||||
impl<'tcx> TransformVisitor<'tcx> {
|
||||
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
|
||||
let block = BasicBlock::new(body.basic_blocks.len());
|
||||
|
||||
let source_info = SourceInfo::outermost(body.span);
|
||||
|
||||
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
|
||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||
let statements = vec![Statement {
|
||||
kind: StatementKind::Assign(Box::new((
|
||||
Place::return_place(),
|
||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||
))),
|
||||
source_info,
|
||||
}];
|
||||
|
||||
body.basic_blocks_mut().push(BasicBlockData {
|
||||
statements,
|
||||
terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
|
||||
is_cleanup: false,
|
||||
});
|
||||
|
||||
block
|
||||
}
|
||||
|
||||
fn coroutine_state_adt_and_variant_idx(
|
||||
&self,
|
||||
is_return: bool,
|
||||
) -> (AggregateKind<'tcx>, VariantIdx) {
|
||||
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
|
||||
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
|
||||
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
|
||||
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
|
||||
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
|
||||
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
|
||||
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
|
||||
});
|
||||
|
||||
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
|
||||
(kind, idx)
|
||||
}
|
||||
|
||||
// Make a `CoroutineState` or `Poll` variant assignment.
|
||||
//
|
||||
// `core::ops::CoroutineState` only has single element tuple variants,
|
||||
|
@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> {
|
|||
is_return: bool,
|
||||
statements: &mut Vec<Statement<'tcx>>,
|
||||
) {
|
||||
let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
|
||||
(true, false) => 1, // CoroutineState::Complete
|
||||
(false, false) => 0, // CoroutineState::Yielded
|
||||
(true, true) => 0, // Poll::Ready
|
||||
(false, true) => 1, // Poll::Pending
|
||||
});
|
||||
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
|
||||
|
||||
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
|
||||
match self.coroutine_kind {
|
||||
// `Poll::Pending`
|
||||
CoroutineKind::Async(_) => {
|
||||
if !is_return {
|
||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||
|
||||
// `Poll::Pending`
|
||||
if self.is_async_kind && idx == VariantIdx::new(1) {
|
||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||
// FIXME(swatinem): assert that `val` is indeed unit?
|
||||
statements.push(Statement {
|
||||
kind: StatementKind::Assign(Box::new((
|
||||
Place::return_place(),
|
||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||
))),
|
||||
source_info,
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
// `Option::None`
|
||||
CoroutineKind::Gen(_) => {
|
||||
if is_return {
|
||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
|
||||
|
||||
// FIXME(swatinem): assert that `val` is indeed unit?
|
||||
statements.push(Statement {
|
||||
kind: StatementKind::Assign(Box::new((
|
||||
Place::return_place(),
|
||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||
))),
|
||||
source_info,
|
||||
});
|
||||
return;
|
||||
statements.push(Statement {
|
||||
kind: StatementKind::Assign(Box::new((
|
||||
Place::return_place(),
|
||||
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
|
||||
))),
|
||||
source_info,
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
CoroutineKind::Coroutine => {}
|
||||
}
|
||||
|
||||
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
|
||||
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
|
||||
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
|
||||
|
||||
statements.push(Statement {
|
||||
|
@ -1263,10 +1317,13 @@ fn create_coroutine_resume_function<'tcx>(
|
|||
}
|
||||
|
||||
if can_return {
|
||||
cases.insert(
|
||||
1,
|
||||
(RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))),
|
||||
);
|
||||
let block = match coroutine_kind {
|
||||
CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
|
||||
insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
|
||||
}
|
||||
CoroutineKind::Gen(_) => transform.insert_none_ret_block(body),
|
||||
};
|
||||
cases.insert(1, (RETURNED, block));
|
||||
}
|
||||
|
||||
insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
|
||||
|
@ -1439,18 +1496,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
|||
};
|
||||
|
||||
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
|
||||
let (state_adt_ref, state_args) = if is_async_kind {
|
||||
// Compute Poll<return_ty>
|
||||
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
||||
let poll_adt_ref = tcx.adt_def(poll_did);
|
||||
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
|
||||
(poll_adt_ref, poll_args)
|
||||
} else {
|
||||
// Compute CoroutineState<yield_ty, return_ty>
|
||||
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
||||
let state_adt_ref = tcx.adt_def(state_did);
|
||||
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
|
||||
(state_adt_ref, state_args)
|
||||
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
|
||||
CoroutineKind::Async(_) => {
|
||||
// Compute Poll<return_ty>
|
||||
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
|
||||
let poll_adt_ref = tcx.adt_def(poll_did);
|
||||
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
|
||||
(poll_adt_ref, poll_args)
|
||||
}
|
||||
CoroutineKind::Gen(_) => {
|
||||
// Compute Option<yield_ty>
|
||||
let option_did = tcx.require_lang_item(LangItem::Option, None);
|
||||
let option_adt_ref = tcx.adt_def(option_did);
|
||||
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
|
||||
(option_adt_ref, option_args)
|
||||
}
|
||||
CoroutineKind::Coroutine => {
|
||||
// Compute CoroutineState<yield_ty, return_ty>
|
||||
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
|
||||
let state_adt_ref = tcx.adt_def(state_did);
|
||||
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
|
||||
(state_adt_ref, state_args)
|
||||
}
|
||||
};
|
||||
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
|
||||
|
||||
|
@ -1518,7 +1585,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
|
|||
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
|
||||
let mut transform = TransformVisitor {
|
||||
tcx,
|
||||
is_async_kind,
|
||||
coroutine_kind: body.coroutine_kind().unwrap(),
|
||||
state_adt_ref,
|
||||
state_args,
|
||||
remap,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue