add the autodiff batch mode frontend
This commit is contained in:
parent
aa8f0fd716
commit
087ffd73bf
5 changed files with 237 additions and 128 deletions
|
@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
|
||||||
/// e.g. in the [JAX
|
/// e.g. in the [JAX
|
||||||
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
||||||
pub mode: DiffMode,
|
pub mode: DiffMode,
|
||||||
|
/// A user-provided batching width. If not given, we will default to 1 (no batching).
|
||||||
|
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
|
||||||
|
/// - Calling the function 50 times with a batch size of 2
|
||||||
|
/// - Calling the function 25 times with a batch size of 4,
|
||||||
|
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
|
||||||
|
/// cache locality, better reuse of primal values, and other optimizations.
|
||||||
|
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
|
||||||
|
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
|
||||||
|
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
|
||||||
|
/// experiments for now and focus on documenting the implications of a large width.
|
||||||
|
pub width: u32,
|
||||||
pub ret_activity: DiffActivity,
|
pub ret_activity: DiffActivity,
|
||||||
pub input_activity: Vec<DiffActivity>,
|
pub input_activity: Vec<DiffActivity>,
|
||||||
}
|
}
|
||||||
|
@ -222,6 +233,7 @@ impl AutoDiffAttrs {
|
||||||
pub const fn error() -> Self {
|
pub const fn error() -> Self {
|
||||||
AutoDiffAttrs {
|
AutoDiffAttrs {
|
||||||
mode: DiffMode::Error,
|
mode: DiffMode::Error,
|
||||||
|
width: 0,
|
||||||
ret_activity: DiffActivity::None,
|
ret_activity: DiffActivity::None,
|
||||||
input_activity: Vec::new(),
|
input_activity: Vec::new(),
|
||||||
}
|
}
|
||||||
|
@ -229,6 +241,7 @@ impl AutoDiffAttrs {
|
||||||
pub fn source() -> Self {
|
pub fn source() -> Self {
|
||||||
AutoDiffAttrs {
|
AutoDiffAttrs {
|
||||||
mode: DiffMode::Source,
|
mode: DiffMode::Source,
|
||||||
|
width: 0,
|
||||||
ret_activity: DiffActivity::None,
|
ret_activity: DiffActivity::None,
|
||||||
input_activity: Vec::new(),
|
input_activity: Vec::new(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
|
||||||
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
||||||
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
||||||
|
|
||||||
|
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
|
||||||
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
||||||
.label = not applicable here
|
.label = not applicable here
|
||||||
.label2 = not a `struct`, `enum` or `union`
|
.label2 = not a `struct`, `enum` or `union`
|
||||||
|
|
|
@ -12,12 +12,12 @@ mod llvm_enzyme {
|
||||||
valid_ty_for_activity,
|
valid_ty_for_activity,
|
||||||
};
|
};
|
||||||
use rustc_ast::ptr::P;
|
use rustc_ast::ptr::P;
|
||||||
use rustc_ast::token::{Token, TokenKind};
|
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
|
||||||
use rustc_ast::tokenstream::*;
|
use rustc_ast::tokenstream::*;
|
||||||
use rustc_ast::visit::AssocCtxt::*;
|
use rustc_ast::visit::AssocCtxt::*;
|
||||||
use rustc_ast::{
|
use rustc_ast::{
|
||||||
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
|
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
||||||
PatKind, TyKind,
|
MetaItemInner, PatKind, QSelf, TyKind,
|
||||||
};
|
};
|
||||||
use rustc_expand::base::{Annotatable, ExtCtxt};
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||||
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
||||||
|
@ -45,6 +45,16 @@ mod llvm_enzyme {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
||||||
|
if let Some(l) = x.lit() {
|
||||||
|
match l.kind {
|
||||||
|
ast::LitKind::Int(val, _) => {
|
||||||
|
// get an Ident from a lit
|
||||||
|
return rustc_span::Ident::from_str(val.get().to_string().as_str());
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let segments = &x.meta_item().unwrap().path.segments;
|
let segments = &x.meta_item().unwrap().path.segments;
|
||||||
assert!(segments.len() == 1);
|
assert!(segments.len() == 1);
|
||||||
segments[0].ident
|
segments[0].ident
|
||||||
|
@ -54,6 +64,14 @@ mod llvm_enzyme {
|
||||||
first_ident(x).name.to_string()
|
first_ident(x).name.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reads a batch width out of a single attribute argument.
///
/// Returns `Some(value)` when `x` is an integer literal, and `None` otherwise
/// (either `x` is not a literal at all, or it is a literal of a different kind).
/// The value is returned as the `u128` stored in the literal; no range check is done here.
fn width(x: &MetaItemInner) -> Option<u128> {
|
||||||
|
// Not a literal at all: bail out with `None`.
let lit = x.lit()?;
|
||||||
|
match lit.kind {
|
||||||
|
ast::LitKind::Int(x, _) => Some(x.get()),
|
||||||
|
_ => return None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn from_ast(
|
pub(crate) fn from_ast(
|
||||||
ecx: &mut ExtCtxt<'_>,
|
ecx: &mut ExtCtxt<'_>,
|
||||||
meta_item: &ThinVec<MetaItemInner>,
|
meta_item: &ThinVec<MetaItemInner>,
|
||||||
|
@ -65,9 +83,32 @@ mod llvm_enzyme {
|
||||||
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||||
return AutoDiffAttrs::error();
|
return AutoDiffAttrs::error();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Now we check whether the user wants autodiff in batch/vector mode or scalar mode.
|
||||||
|
// If they don't specify an integer (= width), we default to scalar mode, i.e. width = 1.
|
||||||
|
let mut first_activity = 2;
|
||||||
|
|
||||||
|
let width = if let [_, _, x, ..] = &meta_item[..]
|
||||||
|
&& let Some(x) = width(x)
|
||||||
|
{
|
||||||
|
first_activity = 3;
|
||||||
|
match x.try_into() {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(_) => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
||||||
|
span: meta_item[2].span(),
|
||||||
|
width: x,
|
||||||
|
});
|
||||||
|
return AutoDiffAttrs::error();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
let mut activities: Vec<DiffActivity> = vec![];
|
let mut activities: Vec<DiffActivity> = vec![];
|
||||||
let mut errors = false;
|
let mut errors = false;
|
||||||
for x in &meta_item[2..] {
|
for x in &meta_item[first_activity..] {
|
||||||
let activity_str = name(&x);
|
let activity_str = name(&x);
|
||||||
let res = DiffActivity::from_str(&activity_str);
|
let res = DiffActivity::from_str(&activity_str);
|
||||||
match res {
|
match res {
|
||||||
|
@ -98,7 +139,20 @@ mod llvm_enzyme {
|
||||||
(&DiffActivity::None, activities.as_slice())
|
(&DiffActivity::None, activities.as_slice())
|
||||||
};
|
};
|
||||||
|
|
||||||
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
|
AutoDiffAttrs {
|
||||||
|
mode,
|
||||||
|
width,
|
||||||
|
ret_activity: *ret_activity,
|
||||||
|
input_activity: input_activity.to_vec(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends the token form of one attribute argument to `ts`, followed by a comma.
///
/// The argument is reduced to an identifier via `first_ident`, converted with
/// `Token::from_ast_ident`, and pushed with `Spacing::Joint`. A comma token carrying
/// `Span::default()` is then pushed with `Spacing::Alone`.
fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
|
||||||
|
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||||
|
let val = first_ident(t);
|
||||||
|
// `t` is shadowed here: from this point on it is the generated token, not the argument.
let t = Token::from_ast_ident(val);
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// We expand the autodiff macro to generate a new placeholder function which passes
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||||
|
@ -197,27 +251,49 @@ mod llvm_enzyme {
|
||||||
|
|
||||||
// create TokenStream from vec elements:
|
// create TokenStream from vec elements:
|
||||||
// meta_item doesn't have a .tokens field
|
// meta_item doesn't have a .tokens field
|
||||||
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
|
||||||
let mut ts: Vec<TokenTree> = vec![];
|
let mut ts: Vec<TokenTree> = vec![];
|
||||||
if meta_item_vec.len() < 2 {
|
if meta_item_vec.len() < 2 {
|
||||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||||
// input and output args.
|
// input and output args.
|
||||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||||
return vec![item];
|
return vec![item];
|
||||||
} else {
|
|
||||||
for t in meta_item_vec.clone()[1..].iter() {
|
|
||||||
let val = first_ident(t);
|
|
||||||
let t = Token::from_ast_ident(val);
|
|
||||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
|
||||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
||||||
|
|
||||||
|
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||||
|
// If it is not given, we default to 1 (scalar mode).
|
||||||
|
let start_position;
|
||||||
|
let kind: LitKind = LitKind::Integer;
|
||||||
|
let symbol;
|
||||||
|
if meta_item_vec.len() >= 3
|
||||||
|
&& let Some(width) = width(&meta_item_vec[2])
|
||||||
|
{
|
||||||
|
start_position = 3;
|
||||||
|
symbol = Symbol::intern(&width.to_string());
|
||||||
|
} else {
|
||||||
|
start_position = 2;
|
||||||
|
symbol = sym::integer(1);
|
||||||
|
}
|
||||||
|
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||||
|
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||||
|
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||||
|
|
||||||
|
for t in meta_item_vec.clone()[start_position..].iter() {
|
||||||
|
meta_item_inner_to_ts(t, &mut ts);
|
||||||
|
}
|
||||||
|
|
||||||
if !has_ret {
|
if !has_ret {
|
||||||
// We don't want users to provide a return activity if the function doesn't return anything.
|
// We don't want users to provide a return activity if the function doesn't return anything.
|
||||||
// For simplicity, we just add a dummy token to the end of the list.
|
// For simplicity, we just add a dummy token to the end of the list.
|
||||||
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
||||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
ts.push(TokenTree::Token(comma, Spacing::Alone));
|
||||||
}
|
}
|
||||||
|
// We remove the last, trailing comma.
|
||||||
|
ts.pop();
|
||||||
let ts: TokenStream = TokenStream::from_iter(ts);
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||||
|
|
||||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||||
|
@ -475,6 +551,8 @@ mod llvm_enzyme {
|
||||||
return body;
|
return body;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Everything from here onwards just tries to fulfill the return type. Fun!
|
||||||
|
|
||||||
// having an active-only return means we'll drop the original return type.
|
// having an active-only return means we'll drop the original return type.
|
||||||
// So that can be treated identical to not having one in the first place.
|
// So that can be treated identical to not having one in the first place.
|
||||||
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
||||||
|
@ -502,86 +580,65 @@ mod llvm_enzyme {
|
||||||
return body;
|
return body;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut exprs = ThinVec::<P<ast::Expr>>::new();
|
let mut exprs: P<ast::Expr> = primal_call.clone();
|
||||||
if primal_ret {
|
|
||||||
// We have both primal ret and active floats.
|
|
||||||
// primal ret is first, by construction.
|
|
||||||
exprs.push(primal_call);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now construct default placeholder for each active float.
|
|
||||||
// Is there something nicer than f32::default() and f64::default()?
|
|
||||||
let d_ret_ty = match d_sig.decl.output {
|
let d_ret_ty = match d_sig.decl.output {
|
||||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
FnRetTy::Default(span) => {
|
FnRetTy::Default(span) => {
|
||||||
panic!("Did not expect Default ret ty: {:?}", span);
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut d_ret_ty = match d_ret_ty.kind.clone() {
|
|
||||||
TyKind::Tup(ref tys) => tys.clone(),
|
|
||||||
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
|
|
||||||
if let [segment] = &segments[..]
|
|
||||||
&& segment.args.is_none()
|
|
||||||
{
|
|
||||||
let id = vec![segments[0].ident];
|
|
||||||
let kind = TyKind::Path(None, ecx.path(span, id));
|
|
||||||
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
|
||||||
thin_vec![ty]
|
|
||||||
} else {
|
|
||||||
panic!("Expected tuple or simple path return type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// We messed up construction of d_sig
|
|
||||||
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
|
if x.mode.is_fwd() {
|
||||||
assert!(d_ret_ty.len() == 2);
|
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
||||||
// both should be identical, by construction
|
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
||||||
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
|
// We checked that (on a best-effort basis) in the preceding gen_enzyme_decl function.
|
||||||
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
|
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
|
||||||
assert!(arg == arg2);
|
if x.ret_activity == DiffActivity::Const {
|
||||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
// Here we call the primal function, since our dummy function has the same return
|
||||||
let tmp = ecx.def_site_path(&sl);
|
// type due to the Const return activity.
|
||||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||||
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
} else {
|
||||||
exprs.push(default_call_expr);
|
let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
|
||||||
} else if x.mode.is_rev() {
|
let y =
|
||||||
if primal_ret {
|
ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
|
||||||
// We have extra handling above for the primal ret
|
let default_call_expr = ecx.expr(span, y);
|
||||||
d_ret_ty = d_ret_ty[1..].to_vec().into();
|
|
||||||
}
|
|
||||||
|
|
||||||
for arg in d_ret_ty.iter() {
|
|
||||||
let arg = arg.kind.is_simple_path().unwrap();
|
|
||||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
|
||||||
let tmp = ecx.def_site_path(&sl);
|
|
||||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
|
||||||
let default_call_expr =
|
let default_call_expr =
|
||||||
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
exprs.push(default_call_expr);
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
|
||||||
}
|
}
|
||||||
|
} else if x.mode.is_rev() {
|
||||||
|
if x.width == 1 {
|
||||||
|
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
|
||||||
|
match d_ret_ty.kind {
|
||||||
|
TyKind::Tup(ref args) => {
|
||||||
|
// We have a tuple return type. We need to create a tuple of the same size
|
||||||
|
// and fill it with default values.
|
||||||
|
let mut exprs2 = thin_vec![exprs];
|
||||||
|
for arg in args.iter().skip(1) {
|
||||||
|
let arg = arg.kind.is_simple_path().unwrap();
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr =
|
||||||
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
exprs2.push(default_call_expr);
|
||||||
|
}
|
||||||
|
exprs = ecx.expr_tuple(new_decl_span, exprs2);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Interestingly, even the `-> ArbitraryType` case
|
||||||
|
// ends up getting matched and handled correctly above,
|
||||||
|
// so we don't have to handle any other case for now.
|
||||||
|
panic!("Unsupported return type: {:?}", d_ret_ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||||
|
} else {
|
||||||
|
unreachable!("Unsupported mode: {:?}", x.mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
let ret: P<ast::Expr>;
|
body.stmts.push(ecx.stmt_expr(exprs));
|
||||||
match &exprs[..] {
|
|
||||||
[] => {
|
|
||||||
assert!(!has_ret(&d_sig.decl.output));
|
|
||||||
// We don't have to match the return type.
|
|
||||||
return body;
|
|
||||||
}
|
|
||||||
[arg] => {
|
|
||||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
|
|
||||||
}
|
|
||||||
args => {
|
|
||||||
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
|
|
||||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert!(has_ret(&d_sig.decl.output));
|
|
||||||
body.stmts.push(ecx.stmt_expr(ret));
|
|
||||||
|
|
||||||
body
|
body
|
||||||
}
|
}
|
||||||
|
@ -689,50 +746,55 @@ mod llvm_enzyme {
|
||||||
match activity {
|
match activity {
|
||||||
DiffActivity::Active => {
|
DiffActivity::Active => {
|
||||||
act_ret.push(arg.ty.clone());
|
act_ret.push(arg.ty.clone());
|
||||||
|
// if width =/= 1, then push [arg.ty; width] to act_ret
|
||||||
}
|
}
|
||||||
DiffActivity::ActiveOnly => {
|
DiffActivity::ActiveOnly => {
|
||||||
// We will add the active scalar to the return type.
|
// We will add the active scalar to the return type.
|
||||||
// This is handled later.
|
// This is handled later.
|
||||||
}
|
}
|
||||||
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
||||||
let mut shadow_arg = arg.clone();
|
for i in 0..x.width {
|
||||||
// We += into the shadow in reverse mode.
|
let mut shadow_arg = arg.clone();
|
||||||
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
// We += into the shadow in reverse mode.
|
||||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||||
ident.name
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
} else {
|
ident.name
|
||||||
debug!("{:#?}", &shadow_arg.pat);
|
} else {
|
||||||
panic!("not an ident?");
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
};
|
panic!("not an ident?");
|
||||||
let name: String = format!("d{}", old_name);
|
};
|
||||||
new_inputs.push(name.clone());
|
let name: String = format!("d{}_{}", old_name, i);
|
||||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
new_inputs.push(name.clone());
|
||||||
shadow_arg.pat = P(ast::Pat {
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
id: ast::DUMMY_NODE_ID,
|
shadow_arg.pat = P(ast::Pat {
|
||||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
id: ast::DUMMY_NODE_ID,
|
||||||
span: shadow_arg.pat.span,
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
tokens: shadow_arg.pat.tokens.clone(),
|
span: shadow_arg.pat.span,
|
||||||
});
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
d_inputs.push(shadow_arg);
|
});
|
||||||
|
d_inputs.push(shadow_arg.clone());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
DiffActivity::Dual | DiffActivity::DualOnly => {
|
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||||
let mut shadow_arg = arg.clone();
|
for i in 0..x.width {
|
||||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
let mut shadow_arg = arg.clone();
|
||||||
ident.name
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
} else {
|
ident.name
|
||||||
debug!("{:#?}", &shadow_arg.pat);
|
} else {
|
||||||
panic!("not an ident?");
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
};
|
panic!("not an ident?");
|
||||||
let name: String = format!("b{}", old_name);
|
};
|
||||||
new_inputs.push(name.clone());
|
let name: String = format!("b{}_{}", old_name, i);
|
||||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
new_inputs.push(name.clone());
|
||||||
shadow_arg.pat = P(ast::Pat {
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
id: ast::DUMMY_NODE_ID,
|
shadow_arg.pat = P(ast::Pat {
|
||||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
id: ast::DUMMY_NODE_ID,
|
||||||
span: shadow_arg.pat.span,
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
tokens: shadow_arg.pat.tokens.clone(),
|
span: shadow_arg.pat.span,
|
||||||
});
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
d_inputs.push(shadow_arg);
|
});
|
||||||
|
d_inputs.push(shadow_arg.clone());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
DiffActivity::Const => {
|
DiffActivity::Const => {
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
|
@ -788,23 +850,48 @@ mod llvm_enzyme {
|
||||||
d_decl.inputs = d_inputs.into();
|
d_decl.inputs = d_inputs.into();
|
||||||
|
|
||||||
if x.mode.is_fwd() {
|
if x.mode.is_fwd() {
|
||||||
|
let ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
// We want to return std::hint::black_box(()).
|
||||||
|
let kind = TyKind::Tup(ThinVec::new());
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||||
|
d_decl.output = FnRetTy::Ty(ty.clone());
|
||||||
|
assert!(matches!(x.ret_activity, DiffActivity::None));
|
||||||
|
// this won't be used below, so any type would be fine.
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if let DiffActivity::Dual = x.ret_activity {
|
if let DiffActivity::Dual = x.ret_activity {
|
||||||
let ty = match d_decl.output {
|
let kind = if x.width == 1 {
|
||||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
// Dual can only be used for f32/f64 ret.
|
||||||
FnRetTy::Default(span) => {
|
// In that case we return now a tuple with two floats.
|
||||||
panic!("Did not expect Default ret ty: {:?}", span);
|
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
||||||
}
|
} else {
|
||||||
|
// We have to return [T; width+1], +1 for the primal return.
|
||||||
|
let anon_const = rustc_ast::AnonConst {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
value: ecx.expr_usize(span, 1 + x.width as usize),
|
||||||
|
};
|
||||||
|
TyKind::Array(ty.clone(), anon_const)
|
||||||
};
|
};
|
||||||
// Dual can only be used for f32/f64 ret.
|
|
||||||
// In that case we return now a tuple with two floats.
|
|
||||||
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
|
|
||||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||||
d_decl.output = FnRetTy::Ty(ty);
|
d_decl.output = FnRetTy::Ty(ty);
|
||||||
}
|
}
|
||||||
if let DiffActivity::DualOnly = x.ret_activity {
|
if let DiffActivity::DualOnly = x.ret_activity {
|
||||||
// No need to change the return type,
|
// No need to change the return type,
|
||||||
// we will just return the shadow in place
|
// we will just return the shadow in place of the primal return.
|
||||||
// of the primal return.
|
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||||
|
if x.width > 1 {
|
||||||
|
let anon_const = rustc_ast::AnonConst {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
value: ecx.expr_usize(span, x.width as usize),
|
||||||
|
};
|
||||||
|
let kind = TyKind::Array(ty.clone(), anon_const);
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||||
|
d_decl.output = FnRetTy::Ty(ty);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -202,6 +202,14 @@ mod autodiff {
|
||||||
pub(crate) mode: String,
|
pub(crate) mode: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Diagnostic bound to the `builtin_macros_autodiff_width` message.
///
/// Constructed in `from_ast` when the integer given as the autodiff width fails the
/// `try_into` conversion to the narrower width type kept in the attributes.
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_width)]
|
||||||
|
pub(crate) struct AutoDiffInvalidWidth {
|
||||||
|
// Location the error is reported at; the emitting site passes the span of the width argument.
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
// The rejected width, still in the `u128` form read from the literal.
pub(crate) width: u128,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Diagnostic)]
|
#[derive(Diagnostic)]
|
||||||
#[diag(builtin_macros_autodiff)]
|
#[diag(builtin_macros_autodiff)]
|
||||||
pub(crate) struct AutoDiffInvalidApplication {
|
pub(crate) struct AutoDiffInvalidApplication {
|
||||||
|
|
|
@ -860,7 +860,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
|
Some(AutoDiffAttrs { mode, width: 1, ret_activity, input_activity: arg_activities })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn provide(providers: &mut Providers) {
|
pub(crate) fn provide(providers: &mut Providers) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue