diff --git a/src/librustdoc/doctest/make.rs b/src/librustdoc/doctest/make.rs index cb14608b35a..810f53636ce 100644 --- a/src/librustdoc/doctest/make.rs +++ b/src/librustdoc/doctest/make.rs @@ -5,15 +5,17 @@ use std::fmt::{self, Write as _}; use std::io; use std::sync::Arc; -use rustc_ast::{self as ast, HasAttrs}; +use rustc_ast::token::{Delimiter, TokenKind}; +use rustc_ast::tokenstream::TokenTree; +use rustc_ast::{self as ast, HasAttrs, StmtKind}; use rustc_errors::ColorConfig; use rustc_errors::emitter::stderr_destination; use rustc_parse::new_parser_from_source_str; use rustc_session::parse::ParseSess; -use rustc_span::FileName; use rustc_span::edition::Edition; use rustc_span::source_map::SourceMap; use rustc_span::symbol::sym; +use rustc_span::{FileName, kw}; use tracing::debug; use super::GlobalTestOptions; @@ -319,7 +321,7 @@ fn parse_source(source: &str, crate_name: &Option<&str>) -> Result source.len() { hi = source.len(); @@ -351,11 +353,8 @@ fn parse_source(source: &str, crate_name: &Option<&str>) -> Result { - check_item(item, info, crate_name, false) - } - _ => {} + if let StmtKind::Item(ref item) = stmt.kind { + check_item(item, info, crate_name, false) } } } @@ -381,8 +380,6 @@ fn parse_source(source: &str, crate_name: &Option<&str>) -> Result {parsed:#?}"); - let result = match parsed { Ok(Some(ref item)) if let ast::ItemKind::Fn(ref fn_item) = item.kind @@ -416,11 +413,31 @@ fn parse_source(source: &str, crate_name: &Option<&str>) -> Result check_item(&item, &mut info, crate_name, true), - ast::StmtKind::Expr(ref expr) if matches!(expr.kind, ast::ExprKind::Err(_)) => { + StmtKind::Item(ref item) => check_item(&item, &mut info, crate_name, true), + StmtKind::Expr(ref expr) if matches!(expr.kind, ast::ExprKind::Err(_)) => { cancel_error_count(&psess); return Err(()); } + StmtKind::MacCall(ref mac_call) if !info.has_main_fn => { + let mut iter = mac_call.mac.args.tokens.iter(); + + while let Some(token) = iter.next() { + if let TokenTree::Token(token, _) = token + && let TokenKind::Ident(name, _) = token.kind + && name == kw::Fn + && let Some(TokenTree::Token(fn_token, _)) = iter.peek() + && let TokenKind::Ident(fn_name, _) = fn_token.kind + && fn_name == sym::main + && let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, _)) = { + iter.next(); + iter.peek() + } + { + info.has_main_fn = true; + break; + } + } + } _ => {} } @@ -433,7 +450,7 @@ fn parse_source(source: &str, crate_name: &Option<&str>) -> Result