add autodiff batching middle-end

This commit is contained in:
Manuel Drehwald 2025-04-03 17:21:21 -04:00
parent 087ffd73bf
commit e0c8ead880

View file

@ -2,7 +2,7 @@ use std::str::FromStr;
use rustc_abi::ExternAbi;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::{MetaItem, MetaItemInner, attr};
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
use rustc_attr_parsing::ReprAttr::ReprAlign;
use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_data_structures::fx::FxHashMap;
@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
return Some(AutoDiffAttrs::source());
}
let [mode, input_activities @ .., ret_activity] = &list[..] else {
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
};
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
p1.segments.first().unwrap().ident
@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
}
};
let width: u32 = match width_meta {
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
let w = p1.segments.first().unwrap().ident;
match w.as_str().parse() {
Ok(val) => val,
Err(_) => {
span_bug!(w.span, "rustc_autodiff width should fit u32");
}
}
}
MetaItemInner::Lit(lit) => {
if let LitKind::Int(val, _) = lit.kind {
match val.get().try_into() {
Ok(val) => val,
Err(_) => {
span_bug!(lit.span, "rustc_autodiff width should fit u32");
}
}
} else {
span_bug!(lit.span, "rustc_autodiff width should be an integer");
}
}
};
// First read the ret symbol from the attribute
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
p1.segments.first().unwrap().ident
@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
}
}
Some(AutoDiffAttrs { mode, width: 1, ret_activity, input_activity: arg_activities })
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
}
pub(crate) fn provide(providers: &mut Providers) {