1
Fork 0

Auto merge of #129458 - EnzymeAD:enzyme-frontend, r=jieyouxu

Autodiff Upstreaming - enzyme frontend

This is an upstream PR for the `autodiff` rustc_builtin_macro that is part of the autodiff feature.

For the full implementation, see: https://github.com/rust-lang/rust/pull/129175

**Content:**
It contains a new `#[autodiff(<args>)]` rustc_builtin_macro, as well as a `#[rustc_autodiff]` builtin attribute.
The autodiff macro is applied on function `f` and will expand to a second function `df` (name given by user).
It will add a dummy body to `df` to make sure it type-checks. The body will later be replaced by enzyme on llvm-ir level,
we therefore don't really care about the content. Most of the changes (700 from 1.2k) are in `compiler/rustc_builtin_macros/src/autodiff.rs`, which expand the macro. Nothing except expansion is implemented for now.
I have a fallback implementation for relevant functions in case that rustc should be build without autodiff support. The default for now will be off, although we want to flip it later (once everything landed) to on for nightly. For the sake of CI, I have flipped the defaults, I'll revert this before merging.

**Dummy function Body:**
The first line is an `inline_asm` nop to make inlining less likely (I have additional checks to prevent this in the middle end of rustc. If `f` gets inlined too early, we can't pass it to enzyme and thus can't differentiate it.
If `df` gets inlined too early, the call site will just compute this dummy code instead of the derivatives, a correctness issue. The following black_box lines make sure that none of the input arguments is getting optimized away before we replace the body.

**Motivation:**
The user facing autodiff macro can verify the user input. Then I write it as args to the rustc_attribute, so from here on I can know that these values should be sensible. A rustc_attribute also turned out to be quite nice to attach this information to the corresponding function and carry it till the backend.
This is also just an experiment, I expect to adjust the user facing autodiff macro based on user feedback, to improve usability.

As a simple example of what this will do, we can see this expansion:
From:
```
#[autodiff(df, Reverse, Duplicated, Const, Active)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    unimplemented!()
}
```
to
```
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
    ::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
#[inline(never)]
pub fn df(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
    unsafe { asm!("NOP"); };
    ::core::hint::black_box(f1(x, y));
    ::core::hint::black_box((dx, dret));
    ::core::hint::black_box(f1(x, y))
}
```
I will add a few more tests once I figured out why rustc rebuilds every time I touch a test.

Tracking:

- https://github.com/rust-lang/rust/issues/124509

try-job: dist-x86_64-msvc
This commit is contained in:
bors 2024-10-15 01:30:01 +00:00
commit 785c83015c
32 changed files with 2128 additions and 1 deletions

View file

@ -0,0 +1,820 @@
//! This module contains the implementation of the `#[autodiff]` attribute.
//! Currently our linter isn't smart enough to see that each import is used in one of the two
//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
#[cfg(llvm_enzyme)]
mod llvm_enzyme {
use std::str::FromStr;
use std::string::String;
use rustc_ast::expand::autodiff_attrs::{
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
};
use rustc_ast::ptr::P;
use rustc_ast::token::{Token, TokenKind};
use rustc_ast::tokenstream::*;
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
PatKind, TyKind,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{Ident, kw, sym};
use rustc_span::{Span, Symbol};
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, trace};
use crate::errors;
// If we have a default `()` return type or explicitley `()` return type,
// then we often can skip doing some work.
fn has_ret(ty: &FnRetTy) -> bool {
match ty {
FnRetTy::Ty(ty) => !ty.kind.is_unit(),
FnRetTy::Default(_) => false,
}
}
fn first_ident(x: &MetaItemInner) -> rustc_span::symbol::Ident {
let segments = &x.meta_item().unwrap().path.segments;
assert!(segments.len() == 1);
segments[0].ident
}
fn name(x: &MetaItemInner) -> String {
first_ident(x).name.to_string()
}
pub(crate) fn from_ast(
ecx: &mut ExtCtxt<'_>,
meta_item: &ThinVec<MetaItemInner>,
has_ret: bool,
) -> AutoDiffAttrs {
let dcx = ecx.sess.dcx();
let mode = name(&meta_item[1]);
let Ok(mode) = DiffMode::from_str(&mode) else {
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
return AutoDiffAttrs::error();
};
let mut activities: Vec<DiffActivity> = vec![];
let mut errors = false;
for x in &meta_item[2..] {
let activity_str = name(&x);
let res = DiffActivity::from_str(&activity_str);
match res {
Ok(x) => activities.push(x),
Err(_) => {
dcx.emit_err(errors::AutoDiffUnknownActivity {
span: x.span(),
act: activity_str,
});
errors = true;
}
};
}
if errors {
return AutoDiffAttrs::error();
}
// If a return type exist, we need to split the last activity,
// otherwise we return None as placeholder.
let (ret_activity, input_activity) = if has_ret {
let Some((last, rest)) = activities.split_last() else {
unreachable!(
"should not be reachable because we counted the number of activities previously"
);
};
(last, rest)
} else {
(&DiffActivity::None, activities.as_slice())
};
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
}
/// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
/// should just prevent early inlining and optimizations which alter the function signature.
/// The exact signature of the generated function depends on the configuration provided by the
/// user, but here is an example:
///
/// ```
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
/// fn sin(x: &Box<f32>) -> f32 {
/// f32::sin(**x)
/// }
/// ```
/// which becomes expanded to:
/// ```
/// #[rustc_autodiff]
/// #[inline(never)]
/// fn sin(x: &Box<f32>) -> f32 {
/// f32::sin(**x)
/// }
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
/// #[inline(never)]
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
/// unsafe {
/// asm!("NOP");
/// };
/// ::core::hint::black_box(sin(x));
/// ::core::hint::black_box((dx, dret));
/// ::core::hint::black_box(sin(x))
/// }
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
/// in CI.
pub(crate) fn expand(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
mut item: Annotatable,
) -> Vec<Annotatable> {
let dcx = ecx.sess.dcx();
// first get the annotable item:
let (sig, is_impl): (FnSig, bool) = match &item {
Annotatable::Item(ref iitem) => {
let sig = match &iitem.kind {
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
(sig.clone(), false)
}
Annotatable::AssocItem(ref assoc_item, _) => {
let sig = match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
(sig.clone(), true)
}
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
ast::MetaItemKind::List(ref vec) => vec.clone(),
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
let has_ret = has_ret(&sig.decl.output);
let sig_span = ecx.with_call_site_ctxt(sig.span);
let (vis, primal) = match &item {
Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
Annotatable::AssocItem(ref assoc_item, _) => {
(assoc_item.vis.clone(), assoc_item.ident.clone())
}
_ => {
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
return vec![item];
}
};
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
let comma: Token = Token::new(TokenKind::Comma, Span::default());
let mut ts: Vec<TokenTree> = vec![];
if meta_item_vec.len() < 2 {
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
// input and output args.
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
return vec![item];
} else {
for t in meta_item_vec.clone()[1..].iter() {
let val = first_ident(t);
let t = Token::from_ast_ident(val);
ts.push(TokenTree::Token(t, Spacing::Joint));
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
}
}
if !has_ret {
// We don't want users to provide a return activity if the function doesn't return anything.
// For simplicity, we just add a dummy token to the end of the list.
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
ts.push(TokenTree::Token(t, Spacing::Joint));
}
let ts: TokenStream = TokenStream::from_iter(ts);
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
if !x.is_active() {
// We encountered an error, so we return the original item.
// This allows us to potentially parse other attributes.
return vec![item];
}
let span = ecx.with_def_site_ctxt(expand_span);
let n_active: u32 = x
.input_activity
.iter()
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let new_decl_span = d_sig.span;
let d_body = gen_enzyme_body(
ecx,
&x,
n_active,
&sig,
&d_sig,
primal,
&new_args,
span,
sig_span,
new_decl_span,
idents,
errored,
);
let d_ident = first_ident(&meta_item_vec[0]);
// The first element of it is the name of the function to be generated
let asdf = Box::new(ast::Fn {
defaultness: ast::Defaultness::Final,
sig: d_sig,
generics: Generics::default(),
body: Some(d_body),
});
let mut rustc_ad_attr =
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
Token::new(TokenKind::Ident(sym::never, false.into()), span),
Spacing::Joint,
)];
let never_arg = ast::DelimArgs {
dspan: ast::tokenstream::DelimSpan::from_single(span),
delim: ast::token::Delimiter::Parenthesis,
tokens: ast::tokenstream::TokenStream::from_iter(ts2),
};
let inline_item = ast::AttrItem {
unsafety: ast::Safety::Default,
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
args: ast::AttrArgs::Delimited(never_arg),
tokens: None,
};
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(inline_never_attr),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
iitem.attrs.push(attr.clone());
}
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
iitem.attrs.push(inline_never.clone());
}
Annotatable::Item(iitem.clone())
}
Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
assoc_item.attrs.push(attr.clone());
}
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
assoc_item.attrs.push(inline_never.clone());
}
Annotatable::AssocItem(assoc_item.clone(), i)
}
_ => {
unreachable!("annotatable kind checked previously")
}
};
// Now update for d_fn
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
dspan: DelimSpan::dummy(),
delim: rustc_ast::token::Delimiter::Parenthesis,
tokens: ts,
});
let d_attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let d_annotatable = if is_impl {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
attrs: thin_vec![d_attr.clone(), inline_never],
id: ast::DUMMY_NODE_ID,
span,
vis,
ident: d_ident,
kind: assoc_item,
tokens: None,
});
Annotatable::AssocItem(d_fn, Impl)
} else {
let mut d_fn = ecx.item(
span,
d_ident,
thin_vec![d_attr.clone(), inline_never],
ItemKind::Fn(asdf),
);
d_fn.vis = vis;
Annotatable::Item(d_fn)
};
return vec![orig_annotatable, d_annotatable];
}
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
// mutable references or ptrs, because Enzyme will write into them.
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
let mut ty = ty.clone();
match ty.kind {
TyKind::Ptr(ref mut mut_ty) => {
mut_ty.mutbl = ast::Mutability::Mut;
}
TyKind::Ref(_, ref mut mut_ty) => {
mut_ty.mutbl = ast::Mutability::Mut;
}
_ => {
panic!("unsupported type: {:?}", ty);
}
}
ty
}
/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we build something that should pass. We also add a inline_asm
/// line, as one more barrier for rustc to prevent inlining of this function.
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
/// this function (which should never happen, since it is only a placeholder).
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
new_decl_span: Span,
idents: Vec<Ident>,
errored: bool,
) -> P<ast::Block> {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
template_strs: Box::new([]),
operands: vec![],
clobber_abis: vec![],
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
line_spans: vec![],
};
let noop_expr = ecx.expr_asm(span, P(noop));
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
let unsf_block = ast::Block {
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
id: ast::DUMMY_NODE_ID,
tokens: None,
rules: unsf,
span,
could_be_bare_literal: false,
};
let unsf_expr = ecx.expr_block(P(unsf_block));
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
let primal_call = gen_primal_call(ecx, span, primal, idents);
let black_box_primal_call =
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![
primal_call.clone()
]);
let tup_args = new_names
.iter()
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
.collect();
let black_box_remaining_args =
ecx.expr_call(sig_span, blackbox_call_expr.clone(), thin_vec![
ecx.expr_tuple(sig_span, tup_args)
]);
let mut body = ecx.block(span, ThinVec::new());
body.stmts.push(ecx.stmt_semi(unsf_expr));
// This uses primal args which won't be available if we errored before
if !errored {
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
}
// having an active-only return means we'll drop the original return type.
// So that can be treated identical to not having one in the first place.
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
return body;
}
if !primal_ret && n_active == 1 {
// Again no tuple return, so return default float val.
let ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
let arg = ty.kind.is_simple_path().unwrap();
let sl: Vec<Symbol> = vec![arg, kw::Default];
let tmp = ecx.def_site_path(&sl);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
body.stmts.push(ecx.stmt_expr(default_call_expr));
return body;
}
let mut exprs = ThinVec::<P<ast::Expr>>::new();
if primal_ret {
// We have both primal ret and active floats.
// primal ret is first, by construction.
exprs.push(primal_call.clone());
}
// Now construct default placeholder for each active float.
// Is there something nicer than f32::default() and f64::default()?
let d_ret_ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
let mut d_ret_ty = match d_ret_ty.kind.clone() {
TyKind::Tup(ref tys) => tys.clone(),
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
if let [segment] = &segments[..]
&& segment.args.is_none()
{
let id = vec![segments[0].ident];
let kind = TyKind::Path(None, ecx.path(span, id));
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
thin_vec![ty]
} else {
panic!("Expected tuple or simple path return type");
}
}
_ => {
// We messed up construction of d_sig
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
}
};
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
assert!(d_ret_ty.len() == 2);
// both should be identical, by construction
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
assert!(arg == arg2);
let sl: Vec<Symbol> = vec![arg, kw::Default];
let tmp = ecx.def_site_path(&sl);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs.push(default_call_expr);
} else if x.mode.is_rev() {
if primal_ret {
// We have extra handling above for the primal ret
d_ret_ty = d_ret_ty[1..].to_vec().into();
}
for arg in d_ret_ty.iter() {
let arg = arg.kind.is_simple_path().unwrap();
let sl: Vec<Symbol> = vec![arg, kw::Default];
let tmp = ecx.def_site_path(&sl);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr =
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs.push(default_call_expr);
}
}
let ret: P<ast::Expr>;
match &exprs[..] {
[] => {
assert!(!has_ret(&d_sig.decl.output));
// We don't have to match the return type.
return body;
}
[arg] => {
ret = ecx
.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![arg.clone()]);
}
args => {
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
ret =
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
}
}
assert!(has_ret(&d_sig.decl.output));
body.stmts.push(ecx.stmt_expr(ret));
body
}
fn gen_primal_call(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: Vec<Ident>,
) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
let args: ThinVec<_> =
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let self_expr = ecx.expr_self(span);
ecx.expr_method_call(span, self_expr, primal, args.clone())
} else {
let args: ThinVec<_> =
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
ecx.expr_call(span, primal_call_expr, args)
}
}
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
// zero-initialized by Enzyme).
// Each argument of the primal function (and the return type if existing) must be annotated with an
// activity.
//
// Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
// both), we emit an error and return the original signature. This allows us to continue parsing.
fn gen_enzyme_decl(
ecx: &ExtCtxt<'_>,
sig: &ast::FnSig,
x: &AutoDiffAttrs,
span: Span,
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
let dcx = ecx.sess.dcx();
let has_ret = has_ret(&sig.decl.output);
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
if sig_args != num_activities {
dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
span,
expected: sig_args,
found: num_activities,
});
// This is not the right signature, but we can continue parsing.
return (sig.clone(), vec![], vec![], true);
}
assert!(sig.decl.inputs.len() == x.input_activity.len());
assert!(has_ret == x.has_ret_activity());
let mut d_decl = sig.decl.clone();
let mut d_inputs = Vec::new();
let mut new_inputs = Vec::new();
let mut idents = Vec::new();
let mut act_ret = ThinVec::new();
// We have two loops, a first one just to check the activities and types and possibly report
// multiple errors in one compilation session.
let mut errors = false;
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
if !valid_input_activity(x.mode, *activity) {
dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
span,
mode: x.mode.to_string(),
act: activity.to_string(),
});
errors = true;
}
if !valid_ty_for_activity(&arg.ty, *activity) {
dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
span: arg.ty.span,
act: activity.to_string(),
});
errors = true;
}
}
if errors {
// This is not the right signature, but we can continue parsing.
return (sig.clone(), new_inputs, idents, true);
}
let unsafe_activities = x
.input_activity
.iter()
.any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
d_inputs.push(arg.clone());
match activity {
DiffActivity::Active => {
act_ret.push(arg.ty.clone());
}
DiffActivity::ActiveOnly => {
// We will add the active scalar to the return type.
// This is handled later.
}
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
let mut shadow_arg = arg.clone();
// We += into the shadow in reverse mode.
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
} else {
debug!("{:#?}", &shadow_arg.pat);
panic!("not an ident?");
};
let name: String = format!("d{}", old_name);
new_inputs.push(name.clone());
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
shadow_arg.pat = P(ast::Pat {
id: ast::DUMMY_NODE_ID,
kind: PatKind::Ident(BindingMode::NONE, ident, None),
span: shadow_arg.pat.span,
tokens: shadow_arg.pat.tokens.clone(),
});
d_inputs.push(shadow_arg);
}
DiffActivity::Dual | DiffActivity::DualOnly => {
let mut shadow_arg = arg.clone();
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
} else {
debug!("{:#?}", &shadow_arg.pat);
panic!("not an ident?");
};
let name: String = format!("b{}", old_name);
new_inputs.push(name.clone());
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
shadow_arg.pat = P(ast::Pat {
id: ast::DUMMY_NODE_ID,
kind: PatKind::Ident(BindingMode::NONE, ident, None),
span: shadow_arg.pat.span,
tokens: shadow_arg.pat.tokens.clone(),
});
d_inputs.push(shadow_arg);
}
DiffActivity::Const => {
// Nothing to do here.
}
DiffActivity::None | DiffActivity::FakeActivitySize => {
panic!("Should not happen");
}
}
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
idents.push(ident.clone());
} else {
panic!("not an ident?");
}
}
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
if active_only_ret {
assert!(x.mode.is_rev());
}
// If we return a scalar in the primal and the scalar is active,
// then add it as last arg to the inputs.
if x.mode.is_rev() {
match x.ret_activity {
DiffActivity::Active | DiffActivity::ActiveOnly => {
let ty = match d_decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
let name = "dret".to_string();
let ident = Ident::from_str_and_span(&name, ty.span);
let shadow_arg = ast::Param {
attrs: ThinVec::new(),
ty: ty.clone(),
pat: P(ast::Pat {
id: ast::DUMMY_NODE_ID,
kind: PatKind::Ident(BindingMode::NONE, ident, None),
span: ty.span,
tokens: None,
}),
id: ast::DUMMY_NODE_ID,
span: ty.span,
is_placeholder: false,
};
d_inputs.push(shadow_arg);
new_inputs.push(name);
}
_ => {}
}
}
d_decl.inputs = d_inputs.into();
if x.mode.is_fwd() {
if let DiffActivity::Dual = x.ret_activity {
let ty = match d_decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
// Dual can only be used for f32/f64 ret.
// In that case we return now a tuple with two floats.
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
d_decl.output = FnRetTy::Ty(ty);
}
if let DiffActivity::DualOnly = x.ret_activity {
// No need to change the return type,
// we will just return the shadow in place
// of the primal return.
}
}
// If we use ActiveOnly, drop the original return value.
d_decl.output =
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
trace!("act_ret: {:?}", act_ret);
// If we have an active input scalar, add it's gradient to the
// return type. This might require changing the return type to a
// tuple.
if act_ret.len() > 0 {
let ret_ty = match d_decl.output {
FnRetTy::Ty(ref ty) => {
if !active_only_ret {
act_ret.insert(0, ty.clone());
}
let kind = TyKind::Tup(act_ret);
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
}
FnRetTy::Default(span) => {
if act_ret.len() == 1 {
act_ret[0].clone()
} else {
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
}
}
};
d_decl.output = FnRetTy::Ty(ret_ty);
}
let mut d_header = sig.header.clone();
if unsafe_activities {
d_header.safety = rustc_ast::Safety::Unsafe(span);
}
let d_sig = FnSig { header: d_header, decl: d_decl, span };
trace!("Generated signature: {:?}", d_sig);
(d_sig, new_inputs, idents, false)
}
}
#[cfg(not(llvm_enzyme))]
mod ad_fallback {
use rustc_ast::ast;
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::Span;
use crate::errors;
pub(crate) fn expand(
ecx: &mut ExtCtxt<'_>,
_expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
return vec![item];
}
}
#[cfg(not(llvm_enzyme))]
pub(crate) use ad_fallback::expand;
#[cfg(llvm_enzyme)]
pub(crate) use llvm_enzyme::expand;

View file

@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics {
pub(crate) span: Span,
}
#[cfg(llvm_enzyme)]
pub(crate) use autodiff::*;
#[cfg(llvm_enzyme)]
mod autodiff {
use super::*;
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_missing_config)]
pub(crate) struct AutoDiffMissingConfig {
#[primary_span]
pub(crate) span: Span,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_unknown_activity)]
pub(crate) struct AutoDiffUnknownActivity {
#[primary_span]
pub(crate) span: Span,
pub(crate) act: String,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_ty_activity)]
pub(crate) struct AutoDiffInvalidTypeForActivity {
#[primary_span]
pub(crate) span: Span,
pub(crate) act: String,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_number_activities)]
pub(crate) struct AutoDiffInvalidNumberActivities {
#[primary_span]
pub(crate) span: Span,
pub(crate) expected: usize,
pub(crate) found: usize,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_mode_activity)]
pub(crate) struct AutoDiffInvalidApplicationModeAct {
#[primary_span]
pub(crate) span: Span,
pub(crate) mode: String,
pub(crate) act: String,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_mode)]
pub(crate) struct AutoDiffInvalidMode {
#[primary_span]
pub(crate) span: Span,
pub(crate) mode: String,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff)]
pub(crate) struct AutoDiffInvalidApplication {
#[primary_span]
pub(crate) span: Span,
}
}
#[cfg(not(llvm_enzyme))]
pub(crate) use ad_fallback::*;
#[cfg(not(llvm_enzyme))]
mod ad_fallback {
use super::*;
#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_not_build)]
pub(crate) struct AutoDiffSupportNotBuild {
#[primary_span]
pub(crate) span: Span,
}
}
#[derive(Diagnostic)]
#[diag(builtin_macros_concat_bytes_invalid)]
pub(crate) struct ConcatBytesInvalid {

View file

@ -5,6 +5,7 @@
#![allow(internal_features)]
#![allow(rustc::diagnostic_outside_of_impl)]
#![allow(rustc::untranslatable_diagnostic)]
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
#![doc(rust_logo)]
#![feature(assert_matches)]
@ -29,6 +30,7 @@ use crate::deriving::*;
mod alloc_error_handler;
mod assert;
mod autodiff;
mod cfg;
mod cfg_accessible;
mod cfg_eval;
@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
register_attr! {
alloc_error_handler: alloc_error_handler::expand,
autodiff: autodiff::expand,
bench: test::expand_bench,
cfg_accessible: cfg_accessible::Expander,
cfg_eval: cfg_eval::expand,