add autodiff batching middle-end
This commit is contained in:
parent
087ffd73bf
commit
e0c8ead880
1 changed files with 28 additions and 4 deletions
|
@ -2,7 +2,7 @@ use std::str::FromStr;
|
|||
|
||||
use rustc_abi::ExternAbi;
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_ast::{MetaItem, MetaItemInner, attr};
|
||||
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
|
||||
use rustc_attr_parsing::ReprAttr::ReprAlign;
|
||||
use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
|
@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
return Some(AutoDiffAttrs::source());
|
||||
}
|
||||
|
||||
let [mode, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
|
||||
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
|
||||
};
|
||||
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
|
||||
p1.segments.first().unwrap().ident
|
||||
|
@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
}
|
||||
};
|
||||
|
||||
let width: u32 = match width_meta {
|
||||
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
|
||||
let w = p1.segments.first().unwrap().ident;
|
||||
match w.as_str().parse() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(w.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
}
|
||||
MetaItemInner::Lit(lit) => {
|
||||
if let LitKind::Int(val, _) = lit.kind {
|
||||
match val.get().try_into() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(lit.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
span_bug!(lit.span, "rustc_autodiff width should be an integer");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// First read the ret symbol from the attribute
|
||||
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
|
||||
p1.segments.first().unwrap().ident
|
||||
|
@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
}
|
||||
}
|
||||
|
||||
Some(AutoDiffAttrs { mode, width: 1, ret_activity, input_activity: arg_activities })
|
||||
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
|
||||
}
|
||||
|
||||
pub(crate) fn provide(providers: &mut Providers) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue