Coroutine closures implement regular Fn traits, when possible

This commit is contained in:
Michael Goulet 2024-02-05 19:59:05 +00:00
parent 08af64e96b
commit b8c93f1223
5 changed files with 142 additions and 18 deletions

View file

@ -2074,7 +2074,9 @@ fn confirm_select_candidate<'cx, 'tcx>(
} else if lang_items.async_iterator_trait() == Some(trait_def_id) {
confirm_async_iterator_candidate(selcx, obligation, data)
} else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
if obligation.predicate.self_ty().is_closure() {
if obligation.predicate.self_ty().is_closure()
|| obligation.predicate.self_ty().is_coroutine_closure()
{
confirm_closure_candidate(selcx, obligation, data)
} else {
confirm_fn_pointer_candidate(selcx, obligation, data)
@ -2386,11 +2388,75 @@ fn confirm_closure_candidate<'cx, 'tcx>(
obligation: &ProjectionTyObligation<'tcx>,
nested: Vec<PredicateObligation<'tcx>>,
) -> Progress<'tcx> {
let tcx = selcx.tcx();
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
let ty::Closure(_, args) = self_ty.kind() else {
unreachable!("expected closure self type for closure candidate, found {self_ty}")
let closure_sig = match *self_ty.kind() {
ty::Closure(_, args) => args.as_closure().sig(),
// Construct a "normal" `FnOnce` signature for coroutine-closure. This is
// basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but
// I didn't see a good way to unify those.
ty::CoroutineClosure(def_id, args) => {
let args = args.as_coroutine_closure();
let kind_ty = args.kind_ty();
args.coroutine_closure_sig().map_bound(|sig| {
// If we know the kind and upvars, use that directly.
// Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay
// the projection, like the `AsyncFn*` traits do.
let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() {
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(def_id),
ty::ClosureKind::FnOnce,
tcx.lifetimes.re_static,
args.tupled_upvars_ty(),
args.coroutine_captures_by_ref_ty(),
)
} else {
let async_fn_kind_trait_def_id =
tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
let upvars_projection_def_id = tcx
.associated_items(async_fn_kind_trait_def_id)
.filter_by_name_unhygienic(sym::Upvars)
.next()
.unwrap()
.def_id;
let tupled_upvars_ty = Ty::new_projection(
tcx,
upvars_projection_def_id,
[
ty::GenericArg::from(kind_ty),
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(),
tcx.lifetimes.re_static.into(),
sig.tupled_inputs_ty.into(),
args.tupled_upvars_ty().into(),
args.coroutine_captures_by_ref_ty().into(),
],
);
sig.to_coroutine(
tcx,
args.parent_args(),
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
)
};
tcx.mk_fn_sig(
[sig.tupled_inputs_ty],
output_ty,
sig.c_variadic,
sig.unsafety,
sig.abi,
)
})
}
_ => {
unreachable!("expected closure self type for closure candidate, found {self_ty}");
}
};
let closure_sig = args.as_closure().sig();
let Normalized { value: closure_sig, obligations } = normalize_with_depth(
selcx,
obligation.param_env,

View file

@ -332,6 +332,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
}
}
ty::CoroutineClosure(def_id, args) => {
let is_const = self.tcx().is_const_fn_raw(def_id);
match self.infcx.closure_kind(self_ty) {
Some(closure_kind) => {
let no_borrows = self
.infcx
.shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty())
.tuple_fields()
.is_empty();
if no_borrows && closure_kind.extends(kind) {
candidates.vec.push(ClosureCandidate { is_const });
} else if kind == ty::ClosureKind::FnOnce {
candidates.vec.push(ClosureCandidate { is_const });
}
}
None => {
if kind == ty::ClosureKind::FnOnce {
candidates.vec.push(ClosureCandidate { is_const });
} else {
// This stays ambiguous until kind+upvars are determined.
candidates.ambiguous = true;
}
}
}
}
ty::Infer(ty::TyVar(_)) => {
debug!("assemble_unboxed_closure_candidates: ambiguous self-type");
candidates.ambiguous = true;

View file

@ -865,17 +865,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let ty::Closure(closure_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
let trait_ref = match *self_ty.kind() {
ty::Closure(_, args) => {
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
}
ty::CoroutineClosure(_, args) => {
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
self.tcx(),
obligation.predicate.def_id(),
[self_ty, sig.tupled_inputs_ty],
)
})
}
_ => {
bug!("closure candidate for non-closure {:?}", obligation);
}
};
let trait_ref =
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_);
let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
Ok(nested)
self.confirm_poly_trait_refs(obligation, trait_ref)
}
#[instrument(skip(self), level = "debug")]