Auto merge of #117805 - estebank:arg-fn-mismatch, r=petrochenkov
On Fn arg mismatch for a fn path, suggest a closure When encountering a fn call that has a path to another fn being passed in, where an `Fn` impl is expected, and the arguments differ, suggest wrapping the argument with a closure with the appropriate arguments. The last `help` is new: ``` error[E0631]: type mismatch in function arguments --> $DIR/E0631.rs:9:9 | LL | fn f(_: u64) {} | ------------ found signature defined here ... LL | foo(f); | --- ^ expected due to this | | | required by a bound introduced by this call | = note: expected function signature `fn(usize) -> _` found function signature `fn(u64) -> _` note: required by a bound in `foo` --> $DIR/E0631.rs:3:11 | LL | fn foo<F: Fn(usize)>(_: F) {} | ^^^^^^^^^ required by this bound in `foo` help: consider wrapping the function in a closure | LL | foo(|arg0: usize| f(/* u64 */)); | +++++++++++++ +++++++++++ ```
This commit is contained in:
commit
87e1447aad
15 changed files with 253 additions and 61 deletions
|
@ -17,7 +17,7 @@ use rustc_errors::{
|
|||
ErrorGuaranteed, MultiSpan, Style, SuggestionStyle,
|
||||
};
|
||||
use rustc_hir as hir;
|
||||
use rustc_hir::def::DefKind;
|
||||
use rustc_hir::def::{DefKind, Res};
|
||||
use rustc_hir::def_id::DefId;
|
||||
use rustc_hir::intravisit::Visitor;
|
||||
use rustc_hir::is_range_literal;
|
||||
|
@ -36,7 +36,7 @@ use rustc_middle::ty::{
|
|||
TypeSuperFoldable, TypeVisitableExt, TypeckResults,
|
||||
};
|
||||
use rustc_span::def_id::LocalDefId;
|
||||
use rustc_span::symbol::{sym, Ident, Symbol};
|
||||
use rustc_span::symbol::{kw, sym, Ident, Symbol};
|
||||
use rustc_span::{BytePos, DesugaringKind, ExpnKind, MacroKind, Span, DUMMY_SP};
|
||||
use rustc_target::spec::abi;
|
||||
use std::borrow::Cow;
|
||||
|
@ -222,6 +222,15 @@ pub trait TypeErrCtxtExt<'tcx> {
|
|||
param_env: ty::ParamEnv<'tcx>,
|
||||
) -> DiagnosticBuilder<'tcx, ErrorGuaranteed>;
|
||||
|
||||
fn note_conflicting_fn_args(
|
||||
&self,
|
||||
err: &mut Diagnostic,
|
||||
cause: &ObligationCauseCode<'tcx>,
|
||||
expected: Ty<'tcx>,
|
||||
found: Ty<'tcx>,
|
||||
param_env: ty::ParamEnv<'tcx>,
|
||||
);
|
||||
|
||||
fn note_conflicting_closure_bounds(
|
||||
&self,
|
||||
cause: &ObligationCauseCode<'tcx>,
|
||||
|
@ -1034,7 +1043,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind else {
|
||||
return;
|
||||
};
|
||||
let hir::def::Res::Local(hir_id) = path.res else {
|
||||
let Res::Local(hir_id) = path.res else {
|
||||
return;
|
||||
};
|
||||
let Some(hir::Node::Pat(pat)) = self.tcx.hir().find(hir_id) else {
|
||||
|
@ -1618,7 +1627,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
}
|
||||
}
|
||||
if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
|
||||
&& let hir::def::Res::Local(hir_id) = path.res
|
||||
&& let Res::Local(hir_id) = path.res
|
||||
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(hir_id)
|
||||
&& let Some(hir::Node::Local(local)) = self.tcx.hir().find_parent(binding.hir_id)
|
||||
&& let None = local.ty
|
||||
|
@ -2005,6 +2014,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
let signature_kind = format!("{argument_kind} signature");
|
||||
err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);
|
||||
|
||||
self.note_conflicting_fn_args(&mut err, cause, expected, found, param_env);
|
||||
self.note_conflicting_closure_bounds(cause, &mut err);
|
||||
|
||||
if let Some(found_node) = found_node {
|
||||
|
@ -2014,6 +2024,151 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
err
|
||||
}
|
||||
|
||||
fn note_conflicting_fn_args(
|
||||
&self,
|
||||
err: &mut Diagnostic,
|
||||
cause: &ObligationCauseCode<'tcx>,
|
||||
expected: Ty<'tcx>,
|
||||
found: Ty<'tcx>,
|
||||
param_env: ty::ParamEnv<'tcx>,
|
||||
) {
|
||||
let ObligationCauseCode::FunctionArgumentObligation { arg_hir_id, .. } = cause else {
|
||||
return;
|
||||
};
|
||||
let ty::FnPtr(expected) = expected.kind() else {
|
||||
return;
|
||||
};
|
||||
let ty::FnPtr(found) = found.kind() else {
|
||||
return;
|
||||
};
|
||||
let Some(Node::Expr(arg)) = self.tcx.hir().find(*arg_hir_id) else {
|
||||
return;
|
||||
};
|
||||
let hir::ExprKind::Path(path) = arg.kind else {
|
||||
return;
|
||||
};
|
||||
let expected_inputs = self.tcx.instantiate_bound_regions_with_erased(*expected).inputs();
|
||||
let found_inputs = self.tcx.instantiate_bound_regions_with_erased(*found).inputs();
|
||||
let both_tys = expected_inputs.iter().copied().zip(found_inputs.iter().copied());
|
||||
|
||||
let arg_expr = |infcx: &InferCtxt<'tcx>, name, expected: Ty<'tcx>, found: Ty<'tcx>| {
|
||||
let (expected_ty, expected_refs) = get_deref_type_and_refs(expected);
|
||||
let (found_ty, found_refs) = get_deref_type_and_refs(found);
|
||||
|
||||
if infcx.can_eq(param_env, found_ty, expected_ty) {
|
||||
if found_refs.len() == expected_refs.len()
|
||||
&& found_refs.iter().eq(expected_refs.iter())
|
||||
{
|
||||
name
|
||||
} else if found_refs.len() > expected_refs.len() {
|
||||
let refs = &found_refs[..found_refs.len() - expected_refs.len()];
|
||||
if found_refs[..expected_refs.len()].iter().eq(expected_refs.iter()) {
|
||||
format!(
|
||||
"{}{name}",
|
||||
refs.iter()
|
||||
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
)
|
||||
} else {
|
||||
// The refs have different mutability.
|
||||
format!(
|
||||
"{}*{name}",
|
||||
refs.iter()
|
||||
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
)
|
||||
}
|
||||
} else if expected_refs.len() > found_refs.len() {
|
||||
format!(
|
||||
"{}{name}",
|
||||
(0..(expected_refs.len() - found_refs.len()))
|
||||
.map(|_| "*")
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}{name}",
|
||||
found_refs
|
||||
.iter()
|
||||
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
|
||||
.chain(found_refs.iter().map(|_| "*".to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(""),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
format!("/* {found} */")
|
||||
}
|
||||
};
|
||||
let args_have_same_underlying_type = both_tys.clone().all(|(expected, found)| {
|
||||
let (expected_ty, _) = get_deref_type_and_refs(expected);
|
||||
let (found_ty, _) = get_deref_type_and_refs(found);
|
||||
self.can_eq(param_env, found_ty, expected_ty)
|
||||
});
|
||||
let (closure_names, call_names): (Vec<_>, Vec<_>) = if args_have_same_underlying_type
|
||||
&& !expected_inputs.is_empty()
|
||||
&& expected_inputs.len() == found_inputs.len()
|
||||
&& let Some(typeck) = &self.typeck_results
|
||||
&& let Res::Def(_, fn_def_id) = typeck.qpath_res(&path, *arg_hir_id)
|
||||
{
|
||||
let closure: Vec<_> = self
|
||||
.tcx
|
||||
.fn_arg_names(fn_def_id)
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, ident)| {
|
||||
if ident.name.is_empty() || ident.name == kw::SelfLower {
|
||||
format!("arg{i}")
|
||||
} else {
|
||||
format!("{ident}")
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let args = closure
|
||||
.iter()
|
||||
.zip(both_tys)
|
||||
.map(|(name, (expected, found))| {
|
||||
arg_expr(self.infcx, name.to_owned(), expected, found)
|
||||
})
|
||||
.collect();
|
||||
(closure, args)
|
||||
} else {
|
||||
let closure_args = expected_inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| format!("arg{i}"))
|
||||
.collect::<Vec<_>>();
|
||||
let call_args = both_tys
|
||||
.enumerate()
|
||||
.map(|(i, (expected, found))| {
|
||||
arg_expr(self.infcx, format!("arg{i}"), expected, found)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(closure_args, call_args)
|
||||
};
|
||||
let closure_names: Vec<_> = closure_names
|
||||
.into_iter()
|
||||
.zip(expected_inputs.iter())
|
||||
.map(|(name, ty)| {
|
||||
format!(
|
||||
"{name}{}",
|
||||
if ty.has_infer_types() { String::new() } else { format!(": {ty}") }
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
err.multipart_suggestion(
|
||||
format!("consider wrapping the function in a closure"),
|
||||
vec![
|
||||
(arg.span.shrink_to_lo(), format!("|{}| ", closure_names.join(", "))),
|
||||
(arg.span.shrink_to_hi(), format!("({})", call_names.join(", "))),
|
||||
],
|
||||
Applicability::MaybeIncorrect,
|
||||
);
|
||||
}
|
||||
|
||||
// Add a note if there are two `Fn`-family bounds that have conflicting argument
|
||||
// requirements, which will always cause a closure to have a type error.
|
||||
fn note_conflicting_closure_bounds(
|
||||
|
@ -3634,7 +3789,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
}
|
||||
}
|
||||
if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
|
||||
&& let hir::Path { res: hir::def::Res::Local(hir_id), .. } = path
|
||||
&& let hir::Path { res: Res::Local(hir_id), .. } = path
|
||||
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(*hir_id)
|
||||
&& let parent_hir_id = self.tcx.hir().parent_id(binding.hir_id)
|
||||
&& let Some(hir::Node::Local(local)) = self.tcx.hir().find(parent_hir_id)
|
||||
|
@ -3894,7 +4049,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
|
|||
);
|
||||
|
||||
if let hir::ExprKind::Path(hir::QPath::Resolved(None, path)) = expr.kind
|
||||
&& let hir::Path { res: hir::def::Res::Local(hir_id), .. } = path
|
||||
&& let hir::Path { res: Res::Local(hir_id), .. } = path
|
||||
&& let Some(hir::Node::Pat(binding)) = self.tcx.hir().find(*hir_id)
|
||||
&& let Some(parent) = self.tcx.hir().find_parent(binding.hir_id)
|
||||
{
|
||||
|
@ -4349,17 +4504,6 @@ fn hint_missing_borrow<'tcx>(
|
|||
|
||||
let args = fn_decl.inputs.iter();
|
||||
|
||||
fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
|
||||
let mut refs = vec![];
|
||||
|
||||
while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
|
||||
ty = *new_ty;
|
||||
refs.push(*mutbl);
|
||||
}
|
||||
|
||||
(ty, refs)
|
||||
}
|
||||
|
||||
let mut to_borrow = Vec::new();
|
||||
let mut remove_borrow = Vec::new();
|
||||
|
||||
|
@ -4519,7 +4663,7 @@ impl<'a, 'hir> hir::intravisit::Visitor<'hir> for ReplaceImplTraitVisitor<'a> {
|
|||
fn visit_ty(&mut self, t: &'hir hir::Ty<'hir>) {
|
||||
if let hir::TyKind::Path(hir::QPath::Resolved(
|
||||
None,
|
||||
hir::Path { res: hir::def::Res::Def(_, segment_did), .. },
|
||||
hir::Path { res: Res::Def(_, segment_did), .. },
|
||||
)) = t.kind
|
||||
{
|
||||
if self.param_did == *segment_did {
|
||||
|
@ -4652,3 +4796,14 @@ pub fn suggest_desugaring_async_fn_to_impl_future_in_trait<'tcx>(
|
|||
|
||||
Some(sugg)
|
||||
}
|
||||
|
||||
fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
|
||||
let mut refs = vec![];
|
||||
|
||||
while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
|
||||
ty = *new_ty;
|
||||
refs.push(*mutbl);
|
||||
}
|
||||
|
||||
(ty, refs)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue