1
Fork 0

Add TyCtxt::is_fn_trait

This commit is contained in:
Maybe Waffle 2022-11-22 18:31:23 +00:00
parent d0c7ed3bea
commit 4b6e1d1c5f
6 changed files with 16 additions and 11 deletions

View file

@ -263,7 +263,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let trait_def_id = projection.trait_def_id(tcx); let trait_def_id = projection.trait_def_id(tcx);
let is_fn = tcx.fn_trait_kind_from_def_id(trait_def_id).is_some(); let is_fn = tcx.is_fn_trait(trait_def_id);
let gen_trait = tcx.require_lang_item(LangItem::Generator, cause_span); let gen_trait = tcx.require_lang_item(LangItem::Generator, cause_span);
let is_gen = gen_trait == trait_def_id; let is_gen = gen_trait == trait_def_id;
if !is_fn && !is_gen { if !is_fn && !is_gen {

View file

@ -2115,7 +2115,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
{ {
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = predicate.kind().skip_binder() if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = predicate.kind().skip_binder()
&& pred.self_ty().peel_refs() == callee_ty && pred.self_ty().peel_refs() == callee_ty
&& self.tcx.fn_trait_kind_from_def_id(pred.def_id()).is_some() && self.tcx.is_fn_trait(pred.def_id())
{ {
err.span_note(span, "callable defined here"); err.span_note(span, "callable defined here");
return; return;

View file

@ -27,6 +27,9 @@ impl<'tcx> TyCtxt<'tcx> {
}) })
} }
/// Given a [`DefId`] of a [`Fn`], [`FnMut`] or [`FnOnce`] traits,
/// returns a corresponding [`ty::ClosureKind`].
/// For any other [`DefId`] return `None`.
pub fn fn_trait_kind_from_def_id(self, id: DefId) -> Option<ty::ClosureKind> { pub fn fn_trait_kind_from_def_id(self, id: DefId) -> Option<ty::ClosureKind> {
let items = self.lang_items(); let items = self.lang_items();
match Some(id) { match Some(id) {
@ -36,6 +39,11 @@ impl<'tcx> TyCtxt<'tcx> {
_ => None, _ => None,
} }
} }
/// Returns `true` if `id` is a `DefId` of [`Fn`], [`FnMut`] or [`FnOnce`] traits.
pub fn is_fn_trait(self, id: DefId) -> bool {
self.fn_trait_kind_from_def_id(id).is_some()
}
} }
/// Returns `true` if the specified `lang_item` must be present for this /// Returns `true` if the specified `lang_item` must be present for this

View file

@ -687,7 +687,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
} }
ObligationCauseCode::BindingObligation(def_id, _) ObligationCauseCode::BindingObligation(def_id, _)
| ObligationCauseCode::ItemObligation(def_id) | ObligationCauseCode::ItemObligation(def_id)
if tcx.fn_trait_kind_from_def_id(*def_id).is_some() => if tcx.is_fn_trait(*def_id) =>
{ {
err.code(rustc_errors::error_code!(E0059)); err.code(rustc_errors::error_code!(E0059));
err.set_primary_message(format!( err.set_primary_message(format!(
@ -847,8 +847,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
); );
} }
let is_fn_trait = let is_fn_trait = tcx.is_fn_trait(trait_ref.def_id());
tcx.fn_trait_kind_from_def_id(trait_ref.def_id()).is_some();
let is_target_feature_fn = if let ty::FnDef(def_id, _) = let is_target_feature_fn = if let ty::FnDef(def_id, _) =
*trait_ref.skip_binder().self_ty().kind() *trait_ref.skip_binder().self_ty().kind()
{ {
@ -2156,7 +2155,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
if generics.params.iter().any(|p| p.name != kw::SelfUpper) if generics.params.iter().any(|p| p.name != kw::SelfUpper)
&& !snippet.ends_with('>') && !snippet.ends_with('>')
&& !generics.has_impl_trait() && !generics.has_impl_trait()
&& !self.tcx.fn_trait_kind_from_def_id(def_id).is_some() && !self.tcx.is_fn_trait(def_id)
{ {
// FIXME: To avoid spurious suggestions in functions where type arguments // FIXME: To avoid spurious suggestions in functions where type arguments
// where already supplied, we check the snippet to make sure it doesn't // where already supplied, we check the snippet to make sure it doesn't

View file

@ -1679,9 +1679,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
) -> Ty<'tcx> { ) -> Ty<'tcx> {
let inputs = trait_ref.skip_binder().substs.type_at(1); let inputs = trait_ref.skip_binder().substs.type_at(1);
let sig = match inputs.kind() { let sig = match inputs.kind() {
ty::Tuple(inputs) ty::Tuple(inputs) if infcx.tcx.is_fn_trait(trait_ref.def_id()) => {
if infcx.tcx.fn_trait_kind_from_def_id(trait_ref.def_id()).is_some() =>
{
infcx.tcx.mk_fn_sig( infcx.tcx.mk_fn_sig(
inputs.iter(), inputs.iter(),
infcx.next_ty_var(TypeVariableOrigin { infcx.next_ty_var(TypeVariableOrigin {
@ -1752,7 +1750,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
&& let predicates = self.tcx.predicates_of(def_id).instantiate_identity(self.tcx) && let predicates = self.tcx.predicates_of(def_id).instantiate_identity(self.tcx)
&& let Some(pred) = predicates.predicates.get(*idx) && let Some(pred) = predicates.predicates.get(*idx)
&& let ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred)) = pred.kind().skip_binder() && let ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred)) = pred.kind().skip_binder()
&& self.tcx.fn_trait_kind_from_def_id(trait_pred.def_id()).is_some() && self.tcx.is_fn_trait(trait_pred.def_id())
{ {
let expected_self = let expected_self =
self.tcx.anonymize_late_bound_regions(pred.kind().rebind(trait_pred.self_ty())); self.tcx.anonymize_late_bound_regions(pred.kind().rebind(trait_pred.self_ty()));

View file

@ -489,7 +489,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
candidates: &mut SelectionCandidateSet<'tcx>, candidates: &mut SelectionCandidateSet<'tcx>,
) { ) {
// We provide impl of all fn traits for fn pointers. // We provide impl of all fn traits for fn pointers.
if self.tcx().fn_trait_kind_from_def_id(obligation.predicate.def_id()).is_none() { if !self.tcx().is_fn_trait(obligation.predicate.def_id()) {
return; return;
} }