diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 61d91874e61..6f68f679cc0 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -617,33 +617,47 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::ExprKind::Closure(c) }; - let generator_hir_id = self.lower_node_id(closure_node_id); - // FIXME: only add track caller if the parent is track_caller - self.lower_attrs( - generator_hir_id, - &[Attribute { - kind: AttrKind::Normal(ptr::P(NormalAttr { - item: AttrItem { - path: Path::from_ident(Ident::new(sym::track_caller, span)), - args: MacArgs::Empty, - tokens: None, - }, - tokens: None, - })), - id: self.tcx.sess.parse_sess.attr_id_generator.mk_attr_id(), - style: AttrStyle::Outer, - span, - }], - ); - let generator = hir::Expr { - hir_id: generator_hir_id, - kind: generator_kind, - span: self.lower_span(span), - }; - - // `future::from_generator`: + let mut parent_has_track_caller = false; + for attrs in self.attrs.values() { + for attr in attrs.into_iter() { + if attr.has_name(sym::track_caller) { + parent_has_track_caller = true; + break; + } + } + if parent_has_track_caller { + break; + } + } let unstable_span = self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone()); + + let hir_id = if parent_has_track_caller { + let generator_hir_id = self.lower_node_id(closure_node_id); + self.lower_attrs( + generator_hir_id, + &[Attribute { + kind: AttrKind::Normal(ptr::P(NormalAttr { + item: AttrItem { + path: Path::from_ident(Ident::new(sym::track_caller, span)), + args: MacArgs::Empty, + tokens: None, + }, + tokens: None, + })), + id: self.tcx.sess.parse_sess.attr_id_generator.mk_attr_id(), + style: AttrStyle::Outer, + span: unstable_span, + }], + ); + generator_hir_id + } else { + self.lower_node_id(closure_node_id) + }; + + let generator = hir::Expr { hir_id, kind: generator_kind, span: self.lower_span(span) }; + + // `future::from_generator`: let gen_future = self.expr_lang_item_path( unstable_span, hir::LangItem::FromGenerator, diff --git a/library/core/src/future/mod.rs b/library/core/src/future/mod.rs index 6487aa08859..107cf92c1c0 100644 --- a/library/core/src/future/mod.rs +++ b/library/core/src/future/mod.rs @@ -82,6 +82,7 @@ where impl> Future for GenFuture { type Output = T::Return; + #[track_caller] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // SAFETY: Safe because we're !Unpin + !Drop, and this is just a field projection. let gen = unsafe { Pin::map_unchecked_mut(self, |s| &mut s.0) }; diff --git a/src/test/ui/async-await/panic-track-caller.rs b/src/test/ui/async-await/track-caller/panic-track-caller.rs similarity index 93% rename from src/test/ui/async-await/panic-track-caller.rs rename to src/test/ui/async-await/track-caller/panic-track-caller.rs index 76776d41c57..4e659da9ee0 100644 --- a/src/test/ui/async-await/panic-track-caller.rs +++ b/src/test/ui/async-await/track-caller/panic-track-caller.rs @@ -70,6 +70,6 @@ fn panicked_at(f: impl FnOnce() + panic::UnwindSafe) -> u32 { } fn main() { - assert_eq!(panicked_at(|| block_on(foo())), 39); - assert_eq!(panicked_at(|| block_on(foo_track_caller())), 52); + assert_eq!(panicked_at(|| block_on(foo())), 40); + assert_eq!(panicked_at(|| block_on(foo_track_caller())), 53); }