1
Fork 0

On Fn arg mismatch for a fn path, suggest a closure

When encountering a fn call that has a path to another fn being passed
in, where an `Fn` impl is expected, and the arguments differ, suggest
wrapping the argument with a closure with the appropriate arguments.
This commit is contained in:
Esteban Küber 2023-11-11 04:32:32 +00:00
parent abe34e9ab1
commit dfe32b6a43
15 changed files with 247 additions and 54 deletions

View file

@ -222,6 +222,15 @@ pub trait TypeErrCtxtExt<'tcx> {
param_env: ty::ParamEnv<'tcx>,
) -> DiagnosticBuilder<'tcx, ErrorGuaranteed>;
fn note_conflicting_fn_args(
&self,
err: &mut Diagnostic,
cause: &ObligationCauseCode<'tcx>,
expected: Ty<'tcx>,
found: Ty<'tcx>,
param_env: ty::ParamEnv<'tcx>,
);
fn note_conflicting_closure_bounds(
&self,
cause: &ObligationCauseCode<'tcx>,
@ -2005,6 +2014,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
let signature_kind = format!("{argument_kind} signature");
err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);
self.note_conflicting_fn_args(&mut err, cause, expected, found, param_env);
self.note_conflicting_closure_bounds(cause, &mut err);
if let Some(found_node) = found_node {
@ -2014,6 +2024,152 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
err
}
fn note_conflicting_fn_args(
&self,
err: &mut Diagnostic,
cause: &ObligationCauseCode<'tcx>,
expected: Ty<'tcx>,
found: Ty<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) {
let ObligationCauseCode::FunctionArgumentObligation { arg_hir_id, .. } = cause else {
return;
};
let ty::FnPtr(expected) = expected.kind() else {
return;
};
let ty::FnPtr(found) = found.kind() else {
return;
};
let Some(Node::Expr(arg)) = self.tcx.hir().find(*arg_hir_id) else {
return;
};
let hir::ExprKind::Path(path) = arg.kind else {
return;
};
let expected_inputs = self.tcx.erase_late_bound_regions(*expected).inputs();
let found_inputs = self.tcx.erase_late_bound_regions(*found).inputs();
let both_tys = expected_inputs.iter().cloned().zip(found_inputs.iter().cloned());
let arg_expr = |infcx: &InferCtxt<'tcx>, name, expected: Ty<'tcx>, found: Ty<'tcx>| {
let (expected_ty, expected_refs) = get_deref_type_and_refs(expected);
let (found_ty, found_refs) = get_deref_type_and_refs(found);
if infcx.can_eq(param_env, found_ty, expected_ty) {
if found_refs.len() == expected_refs.len()
&& found_refs.iter().zip(expected_refs.iter()).all(|(e, f)| e == f)
{
name
} else if found_refs.len() > expected_refs.len() {
if found_refs[..found_refs.len() - expected_refs.len()]
.iter()
.zip(expected_refs.iter())
.any(|(e, f)| e != f)
{
// The refs have different mutability.
format!(
"{}*{name}",
found_refs[..found_refs.len() - expected_refs.len()]
.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.collect::<Vec<_>>()
.join(""),
)
} else {
format!(
"{}{name}",
found_refs[..found_refs.len() - expected_refs.len()]
.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.collect::<Vec<_>>()
.join(""),
)
}
} else if expected_refs.len() > found_refs.len() {
format!(
"{}{name}",
(0..(expected_refs.len() - found_refs.len()))
.map(|_| "*")
.collect::<Vec<_>>()
.join(""),
)
} else {
format!(
"{}{name}",
found_refs
.iter()
.map(|mutbl| format!("&{}", mutbl.prefix_str()))
.chain(found_refs.iter().map(|_| "*".to_string()))
.collect::<Vec<_>>()
.join(""),
)
}
} else {
format!("/* {found} */")
}
};
let (closure_names, call_names): (Vec<_>, Vec<_>) =
if both_tys.clone().all(|(expected, found)| {
let (expected_ty, _) = get_deref_type_and_refs(expected);
let (found_ty, _) = get_deref_type_and_refs(found);
self.can_eq(param_env, found_ty, expected_ty)
}) && !expected_inputs.is_empty()
&& expected_inputs.len() == found_inputs.len()
&& let hir::QPath::Resolved(_, path) = path
&& let hir::def::Res::Def(_, fn_def_id) = path.res
&& let Some(node) = self.tcx.hir().get_if_local(fn_def_id)
&& let Some(body_id) = node.body_id()
{
let closure = self
.tcx
.hir()
.body_param_names(body_id)
.map(|name| format!("{name}"))
.collect();
let args = self
.tcx
.hir()
.body_param_names(body_id)
.zip(both_tys)
.map(|(name, (expected, found))| {
arg_expr(self.infcx, format!("{name}"), expected, found)
})
.collect();
(closure, args)
} else {
let closure_args = expected_inputs
.iter()
.enumerate()
.map(|(i, _)| format!("arg{i}"))
.collect::<Vec<_>>();
let call_args = both_tys
.enumerate()
.map(|(i, (expected, found))| {
arg_expr(self.infcx, format!("arg{i}"), expected, found)
})
.collect::<Vec<_>>();
(closure_args, call_args)
};
let closure_names: Vec<_> = closure_names
.into_iter()
.zip(expected_inputs.iter())
.map(|(name, ty)| {
format!(
"{name}{}",
if ty.has_infer_types() { String::new() } else { format!(": {ty}") }
)
})
.collect();
err.multipart_suggestion(
format!("consider wrapping the function in a closure"),
vec![
(arg.span.shrink_to_lo(), format!("|{}| ", closure_names.join(", "))),
(arg.span.shrink_to_hi(), format!("({})", call_names.join(", "))),
],
Applicability::MaybeIncorrect,
);
}
// Add a note if there are two `Fn`-family bounds that have conflicting argument
// requirements, which will always cause a closure to have a type error.
fn note_conflicting_closure_bounds(
@ -4349,17 +4505,6 @@ fn hint_missing_borrow<'tcx>(
let args = fn_decl.inputs.iter();
fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
let mut refs = vec![];
while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
ty = *new_ty;
refs.push(*mutbl);
}
(ty, refs)
}
let mut to_borrow = Vec::new();
let mut remove_borrow = Vec::new();
@ -4652,3 +4797,14 @@ pub fn suggest_desugaring_async_fn_to_impl_future_in_trait<'tcx>(
Some(sugg)
}
fn get_deref_type_and_refs(mut ty: Ty<'_>) -> (Ty<'_>, Vec<hir::Mutability>) {
let mut refs = vec![];
while let ty::Ref(_, new_ty, mutbl) = ty.kind() {
ty = *new_ty;
refs.push(*mutbl);
}
(ty, refs)
}