1
Fork 0

Auto merge of #115864 - compiler-errors:rpitit-sugg, r=estebank

Suggest desugaring to return-position `impl Future` when an `async fn` in trait fails an auto trait bound

First commit allows us to store the span of the `async` keyword in HIR.

Second commit implements a suggestion to desugar an `async fn` to a return-position `impl Future` in trait to slightly improve the `Send` situation being discussed in #115822.

This suggestion is only made when `#![feature(return_type_notation)]` is not enabled -- if it is, we should instead suggest an appropriate where-clause bound.
This commit is contained in:
bors 2023-09-21 21:12:32 +00:00
commit b3aa8e7168
31 changed files with 348 additions and 54 deletions

View file

@ -1308,7 +1308,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
fn lower_asyncness(&mut self, a: Async) -> hir::IsAsync {
match a {
Async::Yes { .. } => hir::IsAsync::Async,
Async::Yes { span, .. } => hir::IsAsync::Async(span),
Async::No => hir::IsAsync::NotAsync,
}
}

View file

@ -302,7 +302,7 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
if free_region.bound_region.is_named() {
// A named region that is actually named.
Some(RegionName { name, source: RegionNameSource::NamedFreeRegion(span) })
} else if let hir::IsAsync::Async = tcx.asyncness(self.mir_hir_id().owner) {
} else if tcx.asyncness(self.mir_hir_id().owner).is_async() {
// If we spuriously thought that the region is named, we should let the
// system generate a true name for error messages. Currently this can
// happen if we have an elided name in an async fn for example: the

View file

@ -2853,13 +2853,13 @@ impl ImplicitSelfKind {
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
#[derive(HashStable_Generic)]
pub enum IsAsync {
Async,
Async(Span),
NotAsync,
}
impl IsAsync {
pub fn is_async(self) -> bool {
self == IsAsync::Async
matches!(self, IsAsync::Async(_))
}
}
@ -3296,7 +3296,7 @@ pub struct FnHeader {
impl FnHeader {
pub fn is_async(&self) -> bool {
matches!(&self.asyncness, IsAsync::Async)
matches!(&self.asyncness, IsAsync::Async(_))
}
pub fn is_const(&self) -> bool {
@ -4091,10 +4091,10 @@ mod size_asserts {
static_assert_size!(GenericBound<'_>, 48);
static_assert_size!(Generics<'_>, 56);
static_assert_size!(Impl<'_>, 80);
static_assert_size!(ImplItem<'_>, 80);
static_assert_size!(ImplItemKind<'_>, 32);
static_assert_size!(Item<'_>, 80);
static_assert_size!(ItemKind<'_>, 48);
static_assert_size!(ImplItem<'_>, 88);
static_assert_size!(ImplItemKind<'_>, 40);
static_assert_size!(Item<'_>, 88);
static_assert_size!(ItemKind<'_>, 56);
static_assert_size!(Local<'_>, 64);
static_assert_size!(Param<'_>, 32);
static_assert_size!(Pat<'_>, 72);
@ -4105,8 +4105,8 @@ mod size_asserts {
static_assert_size!(Res, 12);
static_assert_size!(Stmt<'_>, 32);
static_assert_size!(StmtKind<'_>, 16);
static_assert_size!(TraitItem<'_>, 80);
static_assert_size!(TraitItemKind<'_>, 40);
static_assert_size!(TraitItem<'_>, 88);
static_assert_size!(TraitItemKind<'_>, 48);
static_assert_size!(Ty<'_>, 48);
static_assert_size!(TyKind<'_>, 32);
// tidy-alphabetical-end

View file

@ -595,7 +595,7 @@ fn compare_asyncness<'tcx>(
trait_m: ty::AssocItem,
delay: bool,
) -> Result<(), ErrorGuaranteed> {
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
if tcx.asyncness(trait_m.def_id).is_async() {
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
ty::Alias(ty::Opaque, ..) => {
// allow both `async fn foo()` and `fn foo() -> impl Future`

View file

@ -112,7 +112,7 @@ fn check_main_fn_ty(tcx: TyCtxt<'_>, main_def_id: DefId) {
}
let main_asyncness = tcx.asyncness(main_def_id);
if let hir::IsAsync::Async = main_asyncness {
if main_asyncness.is_async() {
let asyncness_span = main_fn_asyncness_span(tcx, main_def_id);
tcx.sess.emit_err(errors::MainFunctionAsync { span: main_span, asyncness: asyncness_span });
error = true;
@ -212,7 +212,7 @@ fn check_start_fn_ty(tcx: TyCtxt<'_>, start_def_id: DefId) {
});
error = true;
}
if let hir::IsAsync::Async = sig.header.asyncness {
if sig.header.asyncness.is_async() {
let span = tcx.def_span(it.owner_id);
tcx.sess.emit_err(errors::StartAsync { span: span });
error = true;

View file

@ -1213,7 +1213,7 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
&& let Some(generics) = self.tcx.hir().get_generics(self.tcx.local_parent(param_id))
&& let Some(param) = generics.params.iter().find(|p| p.def_id == param_id)
&& param.is_elided_lifetime()
&& let hir::IsAsync::NotAsync = self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id)
&& !self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id).is_async()
&& !self.tcx.features().anonymous_lifetime_in_impl_trait
{
let mut diag = rustc_session::parse::feature_err(

View file

@ -2304,7 +2304,7 @@ impl<'a> State<'a> {
match header.asyncness {
hir::IsAsync::NotAsync => {}
hir::IsAsync::Async => self.word_nbsp("async"),
hir::IsAsync::Async(_) => self.word_nbsp("async"),
}
self.print_unsafety(header.unsafety);

View file

@ -987,10 +987,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let bound_vars = self.tcx.late_bound_vars(fn_id);
let ty = self.tcx.erase_late_bound_regions(Binder::bind_with_vars(ty, bound_vars));
let ty = match self.tcx.asyncness(fn_id.owner) {
hir::IsAsync::Async => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
span_bug!(fn_decl.output.span(), "failed to get output type of async function")
}),
hir::IsAsync::NotAsync => ty,
ty::Asyncness::No => ty,
};
let ty = self.normalize(expr.span, ty);
if self.can_coerce(found, ty) {

View file

@ -41,7 +41,6 @@ use crate::{
},
EarlyContext, EarlyLintPass, LateContext, LateLintPass, Level, LintContext,
};
use hir::IsAsync;
use rustc_ast::attr;
use rustc_ast::tokenstream::{TokenStream, TokenTree};
use rustc_ast::visit::{FnCtxt, FnKind};
@ -1294,7 +1293,7 @@ impl<'tcx> LateLintPass<'tcx> for UngatedAsyncFnTrackCaller {
span: Span,
def_id: LocalDefId,
) {
if fn_kind.asyncness() == IsAsync::Async
if fn_kind.asyncness().is_async()
&& !cx.tcx.features().async_fn_track_caller
// Now, check if the function has the `#[track_caller]` attribute
&& let Some(attr) = cx.tcx.get_attr(def_id, sym::track_caller)

View file

@ -439,7 +439,7 @@ define_tables! {
coerce_unsized_info: Table<DefIndex, LazyValue<ty::adjustment::CoerceUnsizedInfo>>,
mir_const_qualif: Table<DefIndex, LazyValue<mir::ConstQualifs>>,
rendered_const: Table<DefIndex, LazyValue<String>>,
asyncness: Table<DefIndex, hir::IsAsync>,
asyncness: Table<DefIndex, ty::Asyncness>,
fn_arg_names: Table<DefIndex, LazyArray<Ident>>,
generator_kind: Table<DefIndex, LazyValue<hir::GeneratorKind>>,
trait_def: Table<DefIndex, LazyValue<ty::TraitDef>>,

View file

@ -205,9 +205,9 @@ fixed_size_enum! {
}
fixed_size_enum! {
hir::IsAsync {
( NotAsync )
( Async )
ty::Asyncness {
( Yes )
( No )
}
}

View file

@ -265,6 +265,7 @@ trivial! {
rustc_middle::ty::adjustment::CoerceUnsizedInfo,
rustc_middle::ty::AssocItem,
rustc_middle::ty::AssocItemContainer,
rustc_middle::ty::Asyncness,
rustc_middle::ty::BoundVariableKind,
rustc_middle::ty::DeducedParamAttrs,
rustc_middle::ty::Destructor,

View file

@ -731,7 +731,7 @@ rustc_queries! {
separate_provide_extern
}
query asyncness(key: DefId) -> hir::IsAsync {
query asyncness(key: DefId) -> ty::Asyncness {
desc { |tcx| "checking if the function is async: `{}`", tcx.def_path_str(key) }
separate_provide_extern
}

View file

@ -280,6 +280,19 @@ impl fmt::Display for ImplPolarity {
}
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable, Debug)]
#[derive(TypeFoldable, TypeVisitable)]
pub enum Asyncness {
Yes,
No,
}
impl Asyncness {
pub fn is_async(self) -> bool {
matches!(self, Asyncness::Yes)
}
}
#[derive(Clone, Debug, PartialEq, Eq, Copy, Hash, Encodable, Decodable, HashStable)]
pub enum Visibility<Id = LocalDefId> {
/// Visible everywhere (including in other crates).

View file

@ -62,6 +62,7 @@ trivially_parameterized_over_tcx! {
crate::middle::resolve_bound_vars::ObjectLifetimeDefault,
crate::mir::ConstQualifs,
ty::AssocItemContainer,
ty::Asyncness,
ty::DeducedParamAttrs,
ty::Generics,
ty::ImplPolarity,

View file

@ -987,6 +987,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}
self.explain_hrtb_projection(&mut err, trait_predicate, obligation.param_env, &obligation.cause);
self.suggest_desugaring_async_fn_in_trait(&mut err, trait_ref);
// Return early if the trait is Debug or Display and the invocation
// originates within a standard library macro, because the output

View file

@ -104,7 +104,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn(sig, _, body_id), .. }) => {
self.describe_generator(*body_id).or_else(|| {
Some(match sig.header {
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async function",
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => {
"an async function"
}
_ => "a function",
})
})
@ -118,7 +120,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
..
}) => self.describe_generator(*body_id).or_else(|| {
Some(match sig.header {
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async method",
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => "an async method",
_ => "a method",
})
}),

View file

@ -414,6 +414,12 @@ pub trait TypeErrCtxtExt<'tcx> {
param_env: ty::ParamEnv<'tcx>,
cause: &ObligationCause<'tcx>,
);
fn suggest_desugaring_async_fn_in_trait(
&self,
err: &mut Diagnostic,
trait_ref: ty::PolyTraitRef<'tcx>,
);
}
fn predicate_constraint(generics: &hir::Generics<'_>, pred: ty::Predicate<'_>) -> (Span, String) {
@ -4100,6 +4106,136 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
});
}
}
fn suggest_desugaring_async_fn_in_trait(
&self,
err: &mut Diagnostic,
trait_ref: ty::PolyTraitRef<'tcx>,
) {
// Don't suggest if RTN is active -- we should prefer a where-clause bound instead.
if self.tcx.features().return_type_notation {
return;
}
let trait_def_id = trait_ref.def_id();
// Only suggest specifying auto traits
if !self.tcx.trait_is_auto(trait_def_id) {
return;
}
// Look for an RPITIT
let ty::Alias(ty::Projection, alias_ty) = trait_ref.self_ty().skip_binder().kind() else {
return;
};
let Some(ty::ImplTraitInTraitData::Trait { fn_def_id, opaque_def_id }) =
self.tcx.opt_rpitit_info(alias_ty.def_id)
else {
return;
};
let auto_trait = self.tcx.def_path_str(trait_def_id);
// ... which is a local function
let Some(fn_def_id) = fn_def_id.as_local() else {
// If it's not local, we can at least mention that the method is async, if it is.
if self.tcx.asyncness(fn_def_id).is_async() {
err.span_note(
self.tcx.def_span(fn_def_id),
format!(
"`{}::{}` is an `async fn` in trait, which does not \
automatically imply that its future is `{auto_trait}`",
alias_ty.trait_ref(self.tcx),
self.tcx.item_name(fn_def_id)
),
);
}
return;
};
let Some(hir::Node::TraitItem(item)) = self.tcx.hir().find_by_def_id(fn_def_id) else {
return;
};
// ... whose signature is `async` (i.e. this is an AFIT)
let (sig, body) = item.expect_fn();
let hir::IsAsync::Async(async_span) = sig.header.asyncness else {
return;
};
let Ok(async_span) =
self.tcx.sess.source_map().span_extend_while(async_span, |c| c.is_whitespace())
else {
return;
};
let hir::FnRetTy::Return(hir::Ty { kind: hir::TyKind::OpaqueDef(def, ..), .. }) =
sig.decl.output
else {
// This should never happen, but let's not ICE.
return;
};
// Check that this is *not* a nested `impl Future` RPIT in an async fn
// (i.e. `async fn foo() -> impl Future`)
if def.owner_id.to_def_id() != opaque_def_id {
return;
}
let future = self.tcx.hir().item(*def).expect_opaque_ty();
let Some(hir::GenericBound::LangItemTrait(_, _, _, generics)) = future.bounds.get(0) else {
// `async fn` should always lower to a lang item bound... but don't ICE.
return;
};
let Some(hir::TypeBindingKind::Equality { term: hir::Term::Ty(future_output_ty) }) =
generics.bindings.get(0).map(|binding| binding.kind)
else {
// Also should never happen.
return;
};
let function_name = self.tcx.def_path_str(fn_def_id);
let mut sugg = if future_output_ty.span.is_empty() {
vec![
(async_span, String::new()),
(
future_output_ty.span,
format!(" -> impl std::future::Future<Output = ()> + {auto_trait}"),
),
]
} else {
vec![
(
future_output_ty.span.shrink_to_lo(),
"impl std::future::Future<Output = ".to_owned(),
),
(future_output_ty.span.shrink_to_hi(), format!("> + {auto_trait}")),
(async_span, String::new()),
]
};
// If there's a body, we also need to wrap it in `async {}`
if let hir::TraitFn::Provided(body) = body {
let body = self.tcx.hir().body(*body);
let body_span = body.value.span;
let body_span_without_braces =
body_span.with_lo(body_span.lo() + BytePos(1)).with_hi(body_span.hi() - BytePos(1));
if body_span_without_braces.is_empty() {
sugg.push((body_span_without_braces, " async {} ".to_owned()));
} else {
sugg.extend([
(body_span_without_braces.shrink_to_lo(), "async {".to_owned()),
(body_span_without_braces.shrink_to_hi(), "} ".to_owned()),
]);
}
}
err.multipart_suggestion(
format!(
"`{auto_trait}` can be made part of the associated future's \
guarantees for all implementations of `{function_name}`"
),
sugg,
Applicability::MachineApplicable,
);
}
}
/// Add a hint to add a missing borrow or remove an unnecessary one.

View file

@ -296,9 +296,12 @@ fn issue33140_self_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<EarlyBinder<Ty<'
}
/// Check if a function is async.
fn asyncness(tcx: TyCtxt<'_>, def_id: LocalDefId) -> hir::IsAsync {
fn asyncness(tcx: TyCtxt<'_>, def_id: LocalDefId) -> ty::Asyncness {
let node = tcx.hir().get_by_def_id(def_id);
node.fn_sig().map_or(hir::IsAsync::NotAsync, |sig| sig.header.asyncness)
node.fn_sig().map_or(ty::Asyncness::No, |sig| match sig.header.asyncness {
hir::IsAsync::Async(_) => ty::Asyncness::Yes,
hir::IsAsync::NotAsync => ty::Asyncness::No,
})
}
fn unsizing_params_for_adt<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> BitSet<u32> {