refactor: simplify function-info gathering

This commit is contained in:
HaeNoe 2025-04-03 22:47:30 +02:00
parent 63e825e52a
commit bf69443a9f
No known key found for this signature in database

View file

@ -17,7 +17,7 @@ mod llvm_enzyme {
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
MetaItemInner, PatKind, QSelf, TyKind,
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, kw, sym};
@ -72,6 +72,16 @@ mod llvm_enzyme {
}
}
// Get information about the function the macro is applied to
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
}
_ => None,
}
}
pub(crate) fn from_ast(
ecx: &mut ExtCtxt<'_>,
meta_item: &ThinVec<MetaItemInner>,
@ -201,49 +211,24 @@ mod llvm_enzyme {
let dcx = ecx.sess.dcx();
// first get information about the annotable item:
let (sig, vis, primal) = match &item {
Annotatable::Item(iitem) => {
let (sig, ident) = match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
(sig.clone(), iitem.vis.clone(), ident.clone())
}
let Some((vis, sig, primal)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
let (sig, ident) = match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
}
};
(sig.clone(), assoc_item.vis.clone(), ident.clone())
}
Annotatable::Stmt(stmt) => {
let (sig, vis, ident) = match &stmt.kind {
ast::StmtKind::Item(iitem) => match &iitem.kind {
ast::ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
(sig.clone(), iitem.vis.clone(), ident.clone())
}
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
},
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
(sig, vis, ident)
}
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
_ => None,
}
}
_ => None,
}) else {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
};
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {