From 4b6e1d1c5fc9eaedc119759d99b6105b4f7f5d1e Mon Sep 17 00:00:00 2001 From: Maybe Waffle Date: Tue, 22 Nov 2022 18:31:23 +0000 Subject: [PATCH] Add `TyCtxt::is_fn_trait` --- compiler/rustc_hir_typeck/src/closure.rs | 2 +- compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs | 2 +- compiler/rustc_middle/src/middle/lang_items.rs | 8 ++++++++ .../src/traits/error_reporting/mod.rs | 7 +++---- .../src/traits/error_reporting/suggestions.rs | 6 ++---- .../src/traits/select/candidate_assembly.rs | 2 +- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index 11a5f2b0d90..e584d413c41 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -263,7 +263,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let trait_def_id = projection.trait_def_id(tcx); - let is_fn = tcx.fn_trait_kind_from_def_id(trait_def_id).is_some(); + let is_fn = tcx.is_fn_trait(trait_def_id); let gen_trait = tcx.require_lang_item(LangItem::Generator, cause_span); let is_gen = gen_trait == trait_def_id; if !is_fn && !is_gen { diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index 5c7d21fc2e4..86384c7b93e 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -2115,7 +2115,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { { if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = predicate.kind().skip_binder() && pred.self_ty().peel_refs() == callee_ty - && self.tcx.fn_trait_kind_from_def_id(pred.def_id()).is_some() + && self.tcx.is_fn_trait(pred.def_id()) { err.span_note(span, "callable defined here"); return; diff --git a/compiler/rustc_middle/src/middle/lang_items.rs b/compiler/rustc_middle/src/middle/lang_items.rs index e1bcb74b281..343ea1f00f5 100644 --- a/compiler/rustc_middle/src/middle/lang_items.rs +++ b/compiler/rustc_middle/src/middle/lang_items.rs @@ -27,6 +27,9 @@ impl<'tcx> TyCtxt<'tcx> { }) } + /// Given a [`DefId`] of a [`Fn`], [`FnMut`] or [`FnOnce`] traits, + /// returns a corresponding [`ty::ClosureKind`]. + /// For any other [`DefId`] return `None`. pub fn fn_trait_kind_from_def_id(self, id: DefId) -> Option { let items = self.lang_items(); match Some(id) { @@ -36,6 +39,11 @@ impl<'tcx> TyCtxt<'tcx> { _ => None, } } + + /// Returns `true` if `id` is a `DefId` of [`Fn`], [`FnMut`] or [`FnOnce`] traits. + pub fn is_fn_trait(self, id: DefId) -> bool { + self.fn_trait_kind_from_def_id(id).is_some() + } } /// Returns `true` if the specified `lang_item` must be present for this diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index de35fbfe6d0..e96b9b64e78 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -687,7 +687,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { } ObligationCauseCode::BindingObligation(def_id, _) | ObligationCauseCode::ItemObligation(def_id) - if tcx.fn_trait_kind_from_def_id(*def_id).is_some() => + if tcx.is_fn_trait(*def_id) => { err.code(rustc_errors::error_code!(E0059)); err.set_primary_message(format!( @@ -847,8 +847,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { ); } - let is_fn_trait = - tcx.fn_trait_kind_from_def_id(trait_ref.def_id()).is_some(); + let is_fn_trait = tcx.is_fn_trait(trait_ref.def_id()); let is_target_feature_fn = if let ty::FnDef(def_id, _) = *trait_ref.skip_binder().self_ty().kind() { @@ -2156,7 +2155,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> { if generics.params.iter().any(|p| p.name != kw::SelfUpper) && !snippet.ends_with('>') && !generics.has_impl_trait() - && !self.tcx.fn_trait_kind_from_def_id(def_id).is_some() + && !self.tcx.is_fn_trait(def_id) { // FIXME: To avoid spurious suggestions in functions where type arguments // where already supplied, we check the snippet to make sure it doesn't diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index 4ae9b22f221..ddffe25bc26 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -1679,9 +1679,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { ) -> Ty<'tcx> { let inputs = trait_ref.skip_binder().substs.type_at(1); let sig = match inputs.kind() { - ty::Tuple(inputs) - if infcx.tcx.fn_trait_kind_from_def_id(trait_ref.def_id()).is_some() => - { + ty::Tuple(inputs) if infcx.tcx.is_fn_trait(trait_ref.def_id()) => { infcx.tcx.mk_fn_sig( inputs.iter(), infcx.next_ty_var(TypeVariableOrigin { @@ -1752,7 +1750,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { && let predicates = self.tcx.predicates_of(def_id).instantiate_identity(self.tcx) && let Some(pred) = predicates.predicates.get(*idx) && let ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred)) = pred.kind().skip_binder() - && self.tcx.fn_trait_kind_from_def_id(trait_pred.def_id()).is_some() + && self.tcx.is_fn_trait(trait_pred.def_id()) { let expected_self = self.tcx.anonymize_late_bound_regions(pred.kind().rebind(trait_pred.self_ty())); diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 460ca923c6e..adce549a858 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -489,7 +489,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { candidates: &mut SelectionCandidateSet<'tcx>, ) { // We provide impl of all fn traits for fn pointers. - if self.tcx().fn_trait_kind_from_def_id(obligation.predicate.def_id()).is_none() { + if !self.tcx().is_fn_trait(obligation.predicate.def_id()) { return; }