
Autodiff batching. Enzyme supports batching, which is especially well known from the ML side when training neural networks. There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image) and a target vector. Based on how close our prediction is, we compute our loss, and then use backpropagation to compute the gradients and update our weights. That's quite inefficient, so what you normally do is to pass in a batch of 8/16/.. images and targets and compute the gradients for all of them at once, allowing better optimizations. Enzyme supports batching in two ways. The first one (which I implemented here) just accepts a batch size, and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of ```rs for i in 0..100 { df(x[i], y[i], 1234); } ``` you can now do ```rs for i in (0..100).step_by(4) { df(x[i+0], x[i+1], x[i+2], x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234); } ``` which will give the same results, but allows better compiler optimizations. See the testcase for details. There is a second variant, where we can mark certain arguments, and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days. I will also add more tests for both modes. For anyone preferring a more interactive explanation, here's a video of Tim's LLVM dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU I'll also add some other docs to the dev guide and user docs in another PR. r? ghost Tracking: - https://github.com/rust-lang/rust/issues/124509 - https://github.com/rust-lang/rust/issues/135283
933 lines
38 KiB
Rust
933 lines
38 KiB
Rust
//! This module contains the implementation of the `#[autodiff]` attribute.
|
|
//! Currently our linter isn't smart enough to see that each import is used in one of the two
|
|
//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
|
|
//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
|
|
|
|
mod llvm_enzyme {
|
|
use std::str::FromStr;
|
|
use std::string::String;
|
|
|
|
use rustc_ast::expand::autodiff_attrs::{
|
|
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
|
|
valid_ty_for_activity,
|
|
};
|
|
use rustc_ast::ptr::P;
|
|
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
|
|
use rustc_ast::tokenstream::*;
|
|
use rustc_ast::visit::AssocCtxt::*;
|
|
use rustc_ast::{
|
|
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
|
MetaItemInner, PatKind, QSelf, TyKind,
|
|
};
|
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
|
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
|
use thin_vec::{ThinVec, thin_vec};
|
|
use tracing::{debug, trace};
|
|
|
|
use crate::errors;
|
|
|
|
pub(crate) fn outer_normal_attr(
|
|
kind: &P<rustc_ast::NormalAttr>,
|
|
id: rustc_ast::AttrId,
|
|
span: Span,
|
|
) -> rustc_ast::Attribute {
|
|
let style = rustc_ast::AttrStyle::Outer;
|
|
let kind = rustc_ast::AttrKind::Normal(kind.clone());
|
|
rustc_ast::Attribute { kind, id, style, span }
|
|
}
|
|
|
|
// If we have a default `()` return type or explicitley `()` return type,
|
|
// then we often can skip doing some work.
|
|
fn has_ret(ty: &FnRetTy) -> bool {
|
|
match ty {
|
|
FnRetTy::Ty(ty) => !ty.kind.is_unit(),
|
|
FnRetTy::Default(_) => false,
|
|
}
|
|
}
|
|
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
|
if let Some(l) = x.lit() {
|
|
match l.kind {
|
|
ast::LitKind::Int(val, _) => {
|
|
// get an Ident from a lit
|
|
return rustc_span::Ident::from_str(val.get().to_string().as_str());
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let segments = &x.meta_item().unwrap().path.segments;
|
|
assert!(segments.len() == 1);
|
|
segments[0].ident
|
|
}
|
|
|
|
fn name(x: &MetaItemInner) -> String {
|
|
first_ident(x).name.to_string()
|
|
}
|
|
|
|
fn width(x: &MetaItemInner) -> Option<u128> {
|
|
let lit = x.lit()?;
|
|
match lit.kind {
|
|
ast::LitKind::Int(x, _) => Some(x.get()),
|
|
_ => return None,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn from_ast(
|
|
ecx: &mut ExtCtxt<'_>,
|
|
meta_item: &ThinVec<MetaItemInner>,
|
|
has_ret: bool,
|
|
) -> AutoDiffAttrs {
|
|
let dcx = ecx.sess.dcx();
|
|
let mode = name(&meta_item[1]);
|
|
let Ok(mode) = DiffMode::from_str(&mode) else {
|
|
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
|
return AutoDiffAttrs::error();
|
|
};
|
|
|
|
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
|
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
|
let mut first_activity = 2;
|
|
|
|
let width = if let [_, _, x, ..] = &meta_item[..]
|
|
&& let Some(x) = width(x)
|
|
{
|
|
first_activity = 3;
|
|
match x.try_into() {
|
|
Ok(x) => x,
|
|
Err(_) => {
|
|
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
|
span: meta_item[2].span(),
|
|
width: x,
|
|
});
|
|
return AutoDiffAttrs::error();
|
|
}
|
|
}
|
|
} else {
|
|
1
|
|
};
|
|
|
|
let mut activities: Vec<DiffActivity> = vec![];
|
|
let mut errors = false;
|
|
for x in &meta_item[first_activity..] {
|
|
let activity_str = name(&x);
|
|
let res = DiffActivity::from_str(&activity_str);
|
|
match res {
|
|
Ok(x) => activities.push(x),
|
|
Err(_) => {
|
|
dcx.emit_err(errors::AutoDiffUnknownActivity {
|
|
span: x.span(),
|
|
act: activity_str,
|
|
});
|
|
errors = true;
|
|
}
|
|
};
|
|
}
|
|
if errors {
|
|
return AutoDiffAttrs::error();
|
|
}
|
|
|
|
// If a return type exist, we need to split the last activity,
|
|
// otherwise we return None as placeholder.
|
|
let (ret_activity, input_activity) = if has_ret {
|
|
let Some((last, rest)) = activities.split_last() else {
|
|
unreachable!(
|
|
"should not be reachable because we counted the number of activities previously"
|
|
);
|
|
};
|
|
(last, rest)
|
|
} else {
|
|
(&DiffActivity::None, activities.as_slice())
|
|
};
|
|
|
|
AutoDiffAttrs {
|
|
mode,
|
|
width,
|
|
ret_activity: *ret_activity,
|
|
input_activity: input_activity.to_vec(),
|
|
}
|
|
}
|
|
|
|
fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
|
|
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
|
let val = first_ident(t);
|
|
let t = Token::from_ast_ident(val);
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
|
}
|
|
|
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
|
/// type-checking and can be called by users. The function body of the placeholder function will
|
|
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
|
/// should just prevent early inlining and optimizations which alter the function signature.
|
|
/// The exact signature of the generated function depends on the configuration provided by the
|
|
/// user, but here is an example:
|
|
///
|
|
/// ```
|
|
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
|
|
/// fn sin(x: &Box<f32>) -> f32 {
|
|
/// f32::sin(**x)
|
|
/// }
|
|
/// ```
|
|
/// which becomes expanded to:
|
|
/// ```
|
|
/// #[rustc_autodiff]
|
|
/// #[inline(never)]
|
|
/// fn sin(x: &Box<f32>) -> f32 {
|
|
/// f32::sin(**x)
|
|
/// }
|
|
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
|
|
/// #[inline(never)]
|
|
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
|
|
/// unsafe {
|
|
/// asm!("NOP");
|
|
/// };
|
|
/// ::core::hint::black_box(sin(x));
|
|
/// ::core::hint::black_box((dx, dret));
|
|
/// ::core::hint::black_box(sin(x))
|
|
/// }
|
|
/// ```
|
|
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
|
/// in CI.
|
|
pub(crate) fn expand(
|
|
ecx: &mut ExtCtxt<'_>,
|
|
expand_span: Span,
|
|
meta_item: &ast::MetaItem,
|
|
mut item: Annotatable,
|
|
) -> Vec<Annotatable> {
|
|
if cfg!(not(llvm_enzyme)) {
|
|
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
|
return vec![item];
|
|
}
|
|
let dcx = ecx.sess.dcx();
|
|
// first get the annotable item:
|
|
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
|
|
Annotatable::Item(iitem) => {
|
|
let (ident, sig) = match &iitem.kind {
|
|
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
|
|
_ => {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
};
|
|
(*ident, sig.clone(), false)
|
|
}
|
|
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
|
|
let (ident, sig) = match &assoc_item.kind {
|
|
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
|
|
_ => {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
};
|
|
(*ident, sig.clone(), true)
|
|
}
|
|
_ => {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
};
|
|
|
|
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
|
|
ast::MetaItemKind::List(ref vec) => vec.clone(),
|
|
_ => {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
};
|
|
|
|
let has_ret = has_ret(&sig.decl.output);
|
|
let sig_span = ecx.with_call_site_ctxt(sig.span);
|
|
|
|
let vis = match &item {
|
|
Annotatable::Item(iitem) => iitem.vis.clone(),
|
|
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
|
|
_ => {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
};
|
|
|
|
// create TokenStream from vec elemtents:
|
|
// meta_item doesn't have a .tokens field
|
|
let mut ts: Vec<TokenTree> = vec![];
|
|
if meta_item_vec.len() < 2 {
|
|
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
|
// input and output args.
|
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
|
return vec![item];
|
|
}
|
|
|
|
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
|
|
|
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
|
// If it is not given, we default to 1 (scalar mode).
|
|
let start_position;
|
|
let kind: LitKind = LitKind::Integer;
|
|
let symbol;
|
|
if meta_item_vec.len() >= 3
|
|
&& let Some(width) = width(&meta_item_vec[2])
|
|
{
|
|
start_position = 3;
|
|
symbol = Symbol::intern(&width.to_string());
|
|
} else {
|
|
start_position = 2;
|
|
symbol = sym::integer(1);
|
|
}
|
|
let l: Lit = Lit { kind, symbol, suffix: None };
|
|
let t = Token::new(TokenKind::Literal(l), Span::default());
|
|
let comma = Token::new(TokenKind::Comma, Span::default());
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
|
|
|
for t in meta_item_vec.clone()[start_position..].iter() {
|
|
meta_item_inner_to_ts(t, &mut ts);
|
|
}
|
|
|
|
if !has_ret {
|
|
// We don't want users to provide a return activity if the function doesn't return anything.
|
|
// For simplicity, we just add a dummy token to the end of the list.
|
|
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
ts.push(TokenTree::Token(comma, Spacing::Alone));
|
|
}
|
|
// We remove the last, trailing comma.
|
|
ts.pop();
|
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
|
|
|
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
|
if !x.is_active() {
|
|
// We encountered an error, so we return the original item.
|
|
// This allows us to potentially parse other attributes.
|
|
return vec![item];
|
|
}
|
|
let span = ecx.with_def_site_ctxt(expand_span);
|
|
|
|
let n_active: u32 = x
|
|
.input_activity
|
|
.iter()
|
|
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
|
|
.count() as u32;
|
|
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
|
|
let d_body = gen_enzyme_body(
|
|
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
|
|
);
|
|
|
|
// The first element of it is the name of the function to be generated
|
|
let asdf = Box::new(ast::Fn {
|
|
defaultness: ast::Defaultness::Final,
|
|
sig: d_sig,
|
|
ident: first_ident(&meta_item_vec[0]),
|
|
generics: Generics::default(),
|
|
contract: None,
|
|
body: Some(d_body),
|
|
define_opaque: None,
|
|
});
|
|
let mut rustc_ad_attr =
|
|
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
|
|
|
|
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
|
|
Token::new(TokenKind::Ident(sym::never, false.into()), span),
|
|
Spacing::Joint,
|
|
)];
|
|
let never_arg = ast::DelimArgs {
|
|
dspan: ast::tokenstream::DelimSpan::from_single(span),
|
|
delim: ast::token::Delimiter::Parenthesis,
|
|
tokens: ast::tokenstream::TokenStream::from_iter(ts2),
|
|
};
|
|
let inline_item = ast::AttrItem {
|
|
unsafety: ast::Safety::Default,
|
|
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
|
|
args: ast::AttrArgs::Delimited(never_arg),
|
|
tokens: None,
|
|
};
|
|
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
|
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
|
let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
|
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
|
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
|
|
|
|
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
|
|
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
|
|
match (attr, item) {
|
|
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
|
|
let a = &a.item.path;
|
|
let b = &b.item.path;
|
|
a.segments.len() == b.segments.len()
|
|
&& a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
|
|
}
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
// Don't add it multiple times:
|
|
let orig_annotatable: Annotatable = match item {
|
|
Annotatable::Item(ref mut iitem) => {
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
|
iitem.attrs.push(attr);
|
|
}
|
|
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
|
|
iitem.attrs.push(inline_never.clone());
|
|
}
|
|
Annotatable::Item(iitem.clone())
|
|
}
|
|
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => {
|
|
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
|
|
assoc_item.attrs.push(attr);
|
|
}
|
|
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
|
|
assoc_item.attrs.push(inline_never.clone());
|
|
}
|
|
Annotatable::AssocItem(assoc_item.clone(), i)
|
|
}
|
|
_ => {
|
|
unreachable!("annotatable kind checked previously")
|
|
}
|
|
};
|
|
// Now update for d_fn
|
|
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
|
|
dspan: DelimSpan::dummy(),
|
|
delim: rustc_ast::token::Delimiter::Parenthesis,
|
|
tokens: ts,
|
|
});
|
|
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
|
|
let d_annotatable = if is_impl {
|
|
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
|
let d_fn = P(ast::AssocItem {
|
|
attrs: thin_vec![d_attr, inline_never],
|
|
id: ast::DUMMY_NODE_ID,
|
|
span,
|
|
vis,
|
|
kind: assoc_item,
|
|
tokens: None,
|
|
});
|
|
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
|
|
} else {
|
|
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
|
|
d_fn.vis = vis;
|
|
Annotatable::Item(d_fn)
|
|
};
|
|
|
|
return vec![orig_annotatable, d_annotatable];
|
|
}
|
|
|
|
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
|
|
// mutable references or ptrs, because Enzyme will write into them.
|
|
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
|
|
let mut ty = ty.clone();
|
|
match ty.kind {
|
|
TyKind::Ptr(ref mut mut_ty) => {
|
|
mut_ty.mutbl = ast::Mutability::Mut;
|
|
}
|
|
TyKind::Ref(_, ref mut mut_ty) => {
|
|
mut_ty.mutbl = ast::Mutability::Mut;
|
|
}
|
|
_ => {
|
|
panic!("unsupported type: {:?}", ty);
|
|
}
|
|
}
|
|
ty
|
|
}
|
|
|
|
// Will generate a body of the type:
|
|
// ```
|
|
// {
|
|
// unsafe {
|
|
// asm!("NOP");
|
|
// }
|
|
// ::core::hint::black_box(primal(args));
|
|
// ::core::hint::black_box((args, ret));
|
|
// <This part remains to be done by following function>
|
|
// }
|
|
// ```
|
|
fn init_body_helper(
|
|
ecx: &ExtCtxt<'_>,
|
|
span: Span,
|
|
primal: Ident,
|
|
new_names: &[String],
|
|
sig_span: Span,
|
|
new_decl_span: Span,
|
|
idents: &[Ident],
|
|
errored: bool,
|
|
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
|
|
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
|
|
let noop = ast::InlineAsm {
|
|
asm_macro: ast::AsmMacro::Asm,
|
|
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
|
|
template_strs: Box::new([]),
|
|
operands: vec![],
|
|
clobber_abis: vec![],
|
|
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
|
|
line_spans: vec![],
|
|
};
|
|
let noop_expr = ecx.expr_asm(span, P(noop));
|
|
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
|
|
let unsf_block = ast::Block {
|
|
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
|
|
id: ast::DUMMY_NODE_ID,
|
|
tokens: None,
|
|
rules: unsf,
|
|
span,
|
|
};
|
|
let unsf_expr = ecx.expr_block(P(unsf_block));
|
|
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
|
|
let primal_call = gen_primal_call(ecx, span, primal, idents);
|
|
let black_box_primal_call = ecx.expr_call(
|
|
new_decl_span,
|
|
blackbox_call_expr.clone(),
|
|
thin_vec![primal_call.clone()],
|
|
);
|
|
let tup_args = new_names
|
|
.iter()
|
|
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
|
|
.collect();
|
|
|
|
let black_box_remaining_args = ecx.expr_call(
|
|
sig_span,
|
|
blackbox_call_expr.clone(),
|
|
thin_vec![ecx.expr_tuple(sig_span, tup_args)],
|
|
);
|
|
|
|
let mut body = ecx.block(span, ThinVec::new());
|
|
body.stmts.push(ecx.stmt_semi(unsf_expr));
|
|
|
|
// This uses primal args which won't be available if we errored before
|
|
if !errored {
|
|
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
|
|
}
|
|
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
|
|
|
|
(body, primal_call, black_box_primal_call, blackbox_call_expr)
|
|
}
|
|
|
|
/// We only want this function to type-check, since we will replace the body
|
|
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
|
|
/// so instead we manually build something that should pass the type checker.
|
|
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
|
|
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
|
|
/// bug would ever try to accidentially differentiate this placeholder function body.
|
|
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
|
|
/// from optimizing any arguments away.
|
|
fn gen_enzyme_body(
|
|
ecx: &ExtCtxt<'_>,
|
|
x: &AutoDiffAttrs,
|
|
n_active: u32,
|
|
sig: &ast::FnSig,
|
|
d_sig: &ast::FnSig,
|
|
primal: Ident,
|
|
new_names: &[String],
|
|
span: Span,
|
|
sig_span: Span,
|
|
idents: Vec<Ident>,
|
|
errored: bool,
|
|
) -> P<ast::Block> {
|
|
let new_decl_span = d_sig.span;
|
|
|
|
// Just adding some default inline-asm and black_box usages to prevent early inlining
|
|
// and optimizations which alter the function signature.
|
|
//
|
|
// The bb_primal_call is the black_box call of the primal function. We keep it around,
|
|
// since it has the convenient property of returning the type of the primal function,
|
|
// Remember, we only care to match types here.
|
|
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
|
|
// to prevent rustc from propagating it into the caller.
|
|
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
|
|
ecx,
|
|
span,
|
|
primal,
|
|
new_names,
|
|
sig_span,
|
|
new_decl_span,
|
|
&idents,
|
|
errored,
|
|
);
|
|
|
|
if !has_ret(&d_sig.decl.output) {
|
|
// there is no return type that we have to match, () works fine.
|
|
return body;
|
|
}
|
|
|
|
// Everything from here onwards just tries to fullfil the return type. Fun!
|
|
|
|
// having an active-only return means we'll drop the original return type.
|
|
// So that can be treated identical to not having one in the first place.
|
|
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
|
|
|
if primal_ret && n_active == 0 && x.mode.is_rev() {
|
|
// We only have the primal ret.
|
|
body.stmts.push(ecx.stmt_expr(bb_primal_call));
|
|
return body;
|
|
}
|
|
|
|
if !primal_ret && n_active == 1 {
|
|
// Again no tuple return, so return default float val.
|
|
let ty = match d_sig.decl.output {
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
FnRetTy::Default(span) => {
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
}
|
|
};
|
|
let arg = ty.kind.is_simple_path().unwrap();
|
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
|
let tmp = ecx.def_site_path(&sl);
|
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
|
body.stmts.push(ecx.stmt_expr(default_call_expr));
|
|
return body;
|
|
}
|
|
|
|
let mut exprs: P<ast::Expr> = primal_call.clone();
|
|
let d_ret_ty = match d_sig.decl.output {
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
FnRetTy::Default(span) => {
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
}
|
|
};
|
|
|
|
if x.mode.is_fwd() {
|
|
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
|
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
|
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
|
|
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
|
|
if x.ret_activity == DiffActivity::Const {
|
|
// Here we call the primal function, since our dummy function has the same return
|
|
// type due to the Const return activity.
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
|
} else {
|
|
let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
|
|
let y =
|
|
ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
|
|
let default_call_expr = ecx.expr(span, y);
|
|
let default_call_expr =
|
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
|
|
}
|
|
} else if x.mode.is_rev() {
|
|
if x.width == 1 {
|
|
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
|
|
match d_ret_ty.kind {
|
|
TyKind::Tup(ref args) => {
|
|
// We have a tuple return type. We need to create a tuple of the same size
|
|
// and fill it with default values.
|
|
let mut exprs2 = thin_vec![exprs];
|
|
for arg in args.iter().skip(1) {
|
|
let arg = arg.kind.is_simple_path().unwrap();
|
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
|
let tmp = ecx.def_site_path(&sl);
|
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
|
let default_call_expr =
|
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
|
exprs2.push(default_call_expr);
|
|
}
|
|
exprs = ecx.expr_tuple(new_decl_span, exprs2);
|
|
}
|
|
_ => {
|
|
// Interestingly, even the `-> ArbitraryType` case
|
|
// ends up getting matched and handled correctly above,
|
|
// so we don't have to handle any other case for now.
|
|
panic!("Unsupported return type: {:?}", d_ret_ty);
|
|
}
|
|
}
|
|
}
|
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
|
} else {
|
|
unreachable!("Unsupported mode: {:?}", x.mode);
|
|
}
|
|
|
|
body.stmts.push(ecx.stmt_expr(exprs));
|
|
|
|
body
|
|
}
|
|
|
|
fn gen_primal_call(
|
|
ecx: &ExtCtxt<'_>,
|
|
span: Span,
|
|
primal: Ident,
|
|
idents: &[Ident],
|
|
) -> P<ast::Expr> {
|
|
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
|
|
if has_self {
|
|
let args: ThinVec<_> =
|
|
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
|
let self_expr = ecx.expr_self(span);
|
|
ecx.expr_method_call(span, self_expr, primal, args)
|
|
} else {
|
|
let args: ThinVec<_> =
|
|
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
|
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
|
|
ecx.expr_call(span, primal_call_expr, args)
|
|
}
|
|
}
|
|
|
|
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
|
|
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
|
|
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
|
|
// zero-initialized by Enzyme).
|
|
// Each argument of the primal function (and the return type if existing) must be annotated with an
|
|
// activity.
|
|
//
|
|
// Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
|
|
// both), we emit an error and return the original signature. This allows us to continue parsing.
|
|
// FIXME(Sa4dUs): make individual activities' span available so errors
|
|
// can point to only the activity instead of the entire attribute
|
|
fn gen_enzyme_decl(
|
|
ecx: &ExtCtxt<'_>,
|
|
sig: &ast::FnSig,
|
|
x: &AutoDiffAttrs,
|
|
span: Span,
|
|
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
|
|
let dcx = ecx.sess.dcx();
|
|
let has_ret = has_ret(&sig.decl.output);
|
|
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
|
|
let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
|
|
if sig_args != num_activities {
|
|
dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
|
|
span,
|
|
expected: sig_args,
|
|
found: num_activities,
|
|
});
|
|
// This is not the right signature, but we can continue parsing.
|
|
return (sig.clone(), vec![], vec![], true);
|
|
}
|
|
assert!(sig.decl.inputs.len() == x.input_activity.len());
|
|
assert!(has_ret == x.has_ret_activity());
|
|
let mut d_decl = sig.decl.clone();
|
|
let mut d_inputs = Vec::new();
|
|
let mut new_inputs = Vec::new();
|
|
let mut idents = Vec::new();
|
|
let mut act_ret = ThinVec::new();
|
|
|
|
// We have two loops, a first one just to check the activities and types and possibly report
|
|
// multiple errors in one compilation session.
|
|
let mut errors = false;
|
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
|
if !valid_input_activity(x.mode, *activity) {
|
|
dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
|
|
span,
|
|
mode: x.mode.to_string(),
|
|
act: activity.to_string(),
|
|
});
|
|
errors = true;
|
|
}
|
|
if !valid_ty_for_activity(&arg.ty, *activity) {
|
|
dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
|
|
span: arg.ty.span,
|
|
act: activity.to_string(),
|
|
});
|
|
errors = true;
|
|
}
|
|
}
|
|
|
|
if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
|
|
dcx.emit_err(errors::AutoDiffInvalidRetAct {
|
|
span,
|
|
mode: x.mode.to_string(),
|
|
act: x.ret_activity.to_string(),
|
|
});
|
|
// We don't set `errors = true` to avoid annoying type errors relative
|
|
// to the expanded macro type signature
|
|
}
|
|
|
|
if errors {
|
|
// This is not the right signature, but we can continue parsing.
|
|
return (sig.clone(), new_inputs, idents, true);
|
|
}
|
|
|
|
let unsafe_activities = x
|
|
.input_activity
|
|
.iter()
|
|
.any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
|
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
|
d_inputs.push(arg.clone());
|
|
match activity {
|
|
DiffActivity::Active => {
|
|
act_ret.push(arg.ty.clone());
|
|
// if width =/= 1, then push [arg.ty; width] to act_ret
|
|
}
|
|
DiffActivity::ActiveOnly => {
|
|
// We will add the active scalar to the return type.
|
|
// This is handled later.
|
|
}
|
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
|
for i in 0..x.width {
|
|
let mut shadow_arg = arg.clone();
|
|
// We += into the shadow in reverse mode.
|
|
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
ident.name
|
|
} else {
|
|
debug!("{:#?}", &shadow_arg.pat);
|
|
panic!("not an ident?");
|
|
};
|
|
let name: String = format!("d{}_{}", old_name, i);
|
|
new_inputs.push(name.clone());
|
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
|
shadow_arg.pat = P(ast::Pat {
|
|
id: ast::DUMMY_NODE_ID,
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
span: shadow_arg.pat.span,
|
|
tokens: shadow_arg.pat.tokens.clone(),
|
|
});
|
|
d_inputs.push(shadow_arg.clone());
|
|
}
|
|
}
|
|
DiffActivity::Dual | DiffActivity::DualOnly => {
|
|
for i in 0..x.width {
|
|
let mut shadow_arg = arg.clone();
|
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
ident.name
|
|
} else {
|
|
debug!("{:#?}", &shadow_arg.pat);
|
|
panic!("not an ident?");
|
|
};
|
|
let name: String = format!("b{}_{}", old_name, i);
|
|
new_inputs.push(name.clone());
|
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
|
shadow_arg.pat = P(ast::Pat {
|
|
id: ast::DUMMY_NODE_ID,
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
span: shadow_arg.pat.span,
|
|
tokens: shadow_arg.pat.tokens.clone(),
|
|
});
|
|
d_inputs.push(shadow_arg.clone());
|
|
}
|
|
}
|
|
DiffActivity::Const => {
|
|
// Nothing to do here.
|
|
}
|
|
DiffActivity::None | DiffActivity::FakeActivitySize => {
|
|
panic!("Should not happen");
|
|
}
|
|
}
|
|
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
|
idents.push(ident.clone());
|
|
} else {
|
|
panic!("not an ident?");
|
|
}
|
|
}
|
|
|
|
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
|
|
if active_only_ret {
|
|
assert!(x.mode.is_rev());
|
|
}
|
|
|
|
// If we return a scalar in the primal and the scalar is active,
|
|
// then add it as last arg to the inputs.
|
|
if x.mode.is_rev() {
|
|
match x.ret_activity {
|
|
DiffActivity::Active | DiffActivity::ActiveOnly => {
|
|
let ty = match d_decl.output {
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
FnRetTy::Default(span) => {
|
|
panic!("Did not expect Default ret ty: {:?}", span);
|
|
}
|
|
};
|
|
let name = "dret".to_string();
|
|
let ident = Ident::from_str_and_span(&name, ty.span);
|
|
let shadow_arg = ast::Param {
|
|
attrs: ThinVec::new(),
|
|
ty: ty.clone(),
|
|
pat: P(ast::Pat {
|
|
id: ast::DUMMY_NODE_ID,
|
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
|
span: ty.span,
|
|
tokens: None,
|
|
}),
|
|
id: ast::DUMMY_NODE_ID,
|
|
span: ty.span,
|
|
is_placeholder: false,
|
|
};
|
|
d_inputs.push(shadow_arg);
|
|
new_inputs.push(name);
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
d_decl.inputs = d_inputs.into();
|
|
|
|
if x.mode.is_fwd() {
|
|
let ty = match d_decl.output {
|
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
|
FnRetTy::Default(span) => {
|
|
// We want to return std::hint::black_box(()).
|
|
let kind = TyKind::Tup(ThinVec::new());
|
|
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
|
d_decl.output = FnRetTy::Ty(ty.clone());
|
|
assert!(matches!(x.ret_activity, DiffActivity::None));
|
|
// this won't be used below, so any type would be fine.
|
|
ty
|
|
}
|
|
};
|
|
|
|
if let DiffActivity::Dual = x.ret_activity {
|
|
let kind = if x.width == 1 {
|
|
// Dual can only be used for f32/f64 ret.
|
|
// In that case we return now a tuple with two floats.
|
|
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
|
} else {
|
|
// We have to return [T; width+1], +1 for the primal return.
|
|
let anon_const = rustc_ast::AnonConst {
|
|
id: ast::DUMMY_NODE_ID,
|
|
value: ecx.expr_usize(span, 1 + x.width as usize),
|
|
};
|
|
TyKind::Array(ty.clone(), anon_const)
|
|
};
|
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
|
d_decl.output = FnRetTy::Ty(ty);
|
|
}
|
|
if let DiffActivity::DualOnly = x.ret_activity {
|
|
// No need to change the return type,
|
|
// we will just return the shadow in place of the primal return.
|
|
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
|
if x.width > 1 {
|
|
let anon_const = rustc_ast::AnonConst {
|
|
id: ast::DUMMY_NODE_ID,
|
|
value: ecx.expr_usize(span, x.width as usize),
|
|
};
|
|
let kind = TyKind::Array(ty.clone(), anon_const);
|
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
|
d_decl.output = FnRetTy::Ty(ty);
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we use ActiveOnly, drop the original return value.
|
|
d_decl.output =
|
|
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
|
|
|
|
trace!("act_ret: {:?}", act_ret);
|
|
|
|
// If we have an active input scalar, add it's gradient to the
|
|
// return type. This might require changing the return type to a
|
|
// tuple.
|
|
if act_ret.len() > 0 {
|
|
let ret_ty = match d_decl.output {
|
|
FnRetTy::Ty(ref ty) => {
|
|
if !active_only_ret {
|
|
act_ret.insert(0, ty.clone());
|
|
}
|
|
let kind = TyKind::Tup(act_ret);
|
|
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
|
|
}
|
|
FnRetTy::Default(span) => {
|
|
if act_ret.len() == 1 {
|
|
act_ret[0].clone()
|
|
} else {
|
|
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
|
|
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
|
|
}
|
|
}
|
|
};
|
|
d_decl.output = FnRetTy::Ty(ret_ty);
|
|
}
|
|
|
|
let mut d_header = sig.header.clone();
|
|
if unsafe_activities {
|
|
d_header.safety = rustc_ast::Safety::Unsafe(span);
|
|
}
|
|
let d_sig = FnSig { header: d_header, decl: d_decl, span };
|
|
trace!("Generated signature: {:?}", d_sig);
|
|
(d_sig, new_inputs, idents, false)
|
|
}
|
|
}
|
|
|
|
pub(crate) use llvm_enzyme::expand;
|