Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching Enzyme supports batching, which is especially known from the ML side when training neural networks. There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights. That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations. Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size, and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of ```rs for i in 0..100 { df(x[i], y[i], 1234); } ``` You can now do ```rs for i in 0..100.step_by(4) { df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234); } ``` which will give the same results, but allows better compiler optimizations. See the testcase for details. There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days. I will also add more tests for both modes. For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU I'll also add some other docs to the dev guide and user docs in another PR. r? ghost Tracking: - https://github.com/rust-lang/rust/issues/124509 - https://github.com/rust-lang/rust/issues/135283
This commit is contained in:
commit
c6bf3a01ef
21 changed files with 728 additions and 234 deletions
|
@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
|
|||
/// e.g. in the [JAX
|
||||
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
||||
pub mode: DiffMode,
|
||||
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
|
||||
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
|
||||
/// - Calling the function 50 times with a batch size of 2
|
||||
/// - Calling the function 25 times with a batch size of 4,
|
||||
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
|
||||
/// cache locality, better re-usal of primal values, and other optimizations.
|
||||
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
|
||||
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
|
||||
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
|
||||
/// experiments for now and focus on documenting the implications of a large width.
|
||||
pub width: u32,
|
||||
pub ret_activity: DiffActivity,
|
||||
pub input_activity: Vec<DiffActivity>,
|
||||
}
|
||||
|
@ -222,6 +233,7 @@ impl AutoDiffAttrs {
|
|||
pub const fn error() -> Self {
|
||||
AutoDiffAttrs {
|
||||
mode: DiffMode::Error,
|
||||
width: 0,
|
||||
ret_activity: DiffActivity::None,
|
||||
input_activity: Vec::new(),
|
||||
}
|
||||
|
@ -229,6 +241,7 @@ impl AutoDiffAttrs {
|
|||
pub fn source() -> Self {
|
||||
AutoDiffAttrs {
|
||||
mode: DiffMode::Source,
|
||||
width: 0,
|
||||
ret_activity: DiffActivity::None,
|
||||
input_activity: Vec::new(),
|
||||
}
|
||||
|
|
|
@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
|
|||
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
||||
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
||||
|
||||
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
|
||||
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
||||
.label = not applicable here
|
||||
.label2 = not a `struct`, `enum` or `union`
|
||||
|
|
|
@ -12,12 +12,12 @@ mod llvm_enzyme {
|
|||
valid_ty_for_activity,
|
||||
};
|
||||
use rustc_ast::ptr::P;
|
||||
use rustc_ast::token::{Token, TokenKind};
|
||||
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
|
||||
use rustc_ast::tokenstream::*;
|
||||
use rustc_ast::visit::AssocCtxt::*;
|
||||
use rustc_ast::{
|
||||
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
|
||||
PatKind, TyKind,
|
||||
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
||||
MetaItemInner, PatKind, QSelf, TyKind,
|
||||
};
|
||||
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
||||
|
@ -45,6 +45,16 @@ mod llvm_enzyme {
|
|||
}
|
||||
}
|
||||
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
||||
if let Some(l) = x.lit() {
|
||||
match l.kind {
|
||||
ast::LitKind::Int(val, _) => {
|
||||
// get an Ident from a lit
|
||||
return rustc_span::Ident::from_str(val.get().to_string().as_str());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let segments = &x.meta_item().unwrap().path.segments;
|
||||
assert!(segments.len() == 1);
|
||||
segments[0].ident
|
||||
|
@ -54,6 +64,14 @@ mod llvm_enzyme {
|
|||
first_ident(x).name.to_string()
|
||||
}
|
||||
|
||||
fn width(x: &MetaItemInner) -> Option<u128> {
|
||||
let lit = x.lit()?;
|
||||
match lit.kind {
|
||||
ast::LitKind::Int(x, _) => Some(x.get()),
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_ast(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
meta_item: &ThinVec<MetaItemInner>,
|
||||
|
@ -65,9 +83,32 @@ mod llvm_enzyme {
|
|||
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||
return AutoDiffAttrs::error();
|
||||
};
|
||||
|
||||
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
||||
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
||||
let mut first_activity = 2;
|
||||
|
||||
let width = if let [_, _, x, ..] = &meta_item[..]
|
||||
&& let Some(x) = width(x)
|
||||
{
|
||||
first_activity = 3;
|
||||
match x.try_into() {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
||||
span: meta_item[2].span(),
|
||||
width: x,
|
||||
});
|
||||
return AutoDiffAttrs::error();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let mut activities: Vec<DiffActivity> = vec![];
|
||||
let mut errors = false;
|
||||
for x in &meta_item[2..] {
|
||||
for x in &meta_item[first_activity..] {
|
||||
let activity_str = name(&x);
|
||||
let res = DiffActivity::from_str(&activity_str);
|
||||
match res {
|
||||
|
@ -98,7 +139,20 @@ mod llvm_enzyme {
|
|||
(&DiffActivity::None, activities.as_slice())
|
||||
};
|
||||
|
||||
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
|
||||
AutoDiffAttrs {
|
||||
mode,
|
||||
width,
|
||||
ret_activity: *ret_activity,
|
||||
input_activity: input_activity.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
|
||||
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||
let val = first_ident(t);
|
||||
let t = Token::from_ast_ident(val);
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
}
|
||||
|
||||
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||
|
@ -195,27 +249,49 @@ mod llvm_enzyme {
|
|||
|
||||
// create TokenStream from vec elemtents:
|
||||
// meta_item doesn't have a .tokens field
|
||||
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||
let mut ts: Vec<TokenTree> = vec![];
|
||||
if meta_item_vec.len() < 2 {
|
||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||
// input and output args.
|
||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||
return vec![item];
|
||||
} else {
|
||||
for t in meta_item_vec.clone()[1..].iter() {
|
||||
let val = first_ident(t);
|
||||
let t = Token::from_ast_ident(val);
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
}
|
||||
}
|
||||
|
||||
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
||||
|
||||
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||
// If it is not given, we default to 1 (scalar mode).
|
||||
let start_position;
|
||||
let kind: LitKind = LitKind::Integer;
|
||||
let symbol;
|
||||
if meta_item_vec.len() >= 3
|
||||
&& let Some(width) = width(&meta_item_vec[2])
|
||||
{
|
||||
start_position = 3;
|
||||
symbol = Symbol::intern(&width.to_string());
|
||||
} else {
|
||||
start_position = 2;
|
||||
symbol = sym::integer(1);
|
||||
}
|
||||
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
|
||||
for t in meta_item_vec.clone()[start_position..].iter() {
|
||||
meta_item_inner_to_ts(t, &mut ts);
|
||||
}
|
||||
|
||||
if !has_ret {
|
||||
// We don't want users to provide a return activity if the function doesn't return anything.
|
||||
// For simplicity, we just add a dummy token to the end of the list.
|
||||
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma, Spacing::Alone));
|
||||
}
|
||||
// We remove the last, trailing comma.
|
||||
ts.pop();
|
||||
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||
|
||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||
|
@ -470,6 +546,8 @@ mod llvm_enzyme {
|
|||
return body;
|
||||
}
|
||||
|
||||
// Everything from here onwards just tries to fullfil the return type. Fun!
|
||||
|
||||
// having an active-only return means we'll drop the original return type.
|
||||
// So that can be treated identical to not having one in the first place.
|
||||
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
||||
|
@ -497,86 +575,65 @@ mod llvm_enzyme {
|
|||
return body;
|
||||
}
|
||||
|
||||
let mut exprs = ThinVec::<P<ast::Expr>>::new();
|
||||
if primal_ret {
|
||||
// We have both primal ret and active floats.
|
||||
// primal ret is first, by construction.
|
||||
exprs.push(primal_call);
|
||||
}
|
||||
|
||||
// Now construct default placeholder for each active float.
|
||||
// Is there something nicer than f32::default() and f64::default()?
|
||||
let mut exprs: P<ast::Expr> = primal_call.clone();
|
||||
let d_ret_ty = match d_sig.decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
panic!("Did not expect Default ret ty: {:?}", span);
|
||||
}
|
||||
};
|
||||
let mut d_ret_ty = match d_ret_ty.kind.clone() {
|
||||
TyKind::Tup(ref tys) => tys.clone(),
|
||||
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
|
||||
if let [segment] = &segments[..]
|
||||
&& segment.args.is_none()
|
||||
{
|
||||
let id = vec![segments[0].ident];
|
||||
let kind = TyKind::Path(None, ecx.path(span, id));
|
||||
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||
thin_vec![ty]
|
||||
} else {
|
||||
panic!("Expected tuple or simple path return type");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// We messed up construction of d_sig
|
||||
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
|
||||
}
|
||||
};
|
||||
|
||||
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
|
||||
assert!(d_ret_ty.len() == 2);
|
||||
// both should be identical, by construction
|
||||
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
|
||||
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
|
||||
assert!(arg == arg2);
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs.push(default_call_expr);
|
||||
} else if x.mode.is_rev() {
|
||||
if primal_ret {
|
||||
// We have extra handling above for the primal ret
|
||||
d_ret_ty = d_ret_ty[1..].to_vec().into();
|
||||
}
|
||||
|
||||
for arg in d_ret_ty.iter() {
|
||||
let arg = arg.kind.is_simple_path().unwrap();
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
if x.mode.is_fwd() {
|
||||
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
||||
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
||||
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
|
||||
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
|
||||
if x.ret_activity == DiffActivity::Const {
|
||||
// Here we call the primal function, since our dummy function has the same return
|
||||
// type due to the Const return activity.
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||
} else {
|
||||
let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
|
||||
let y =
|
||||
ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
|
||||
let default_call_expr = ecx.expr(span, y);
|
||||
let default_call_expr =
|
||||
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs.push(default_call_expr);
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
|
||||
}
|
||||
} else if x.mode.is_rev() {
|
||||
if x.width == 1 {
|
||||
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
|
||||
match d_ret_ty.kind {
|
||||
TyKind::Tup(ref args) => {
|
||||
// We have a tuple return type. We need to create a tuple of the same size
|
||||
// and fill it with default values.
|
||||
let mut exprs2 = thin_vec![exprs];
|
||||
for arg in args.iter().skip(1) {
|
||||
let arg = arg.kind.is_simple_path().unwrap();
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
let default_call_expr =
|
||||
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs2.push(default_call_expr);
|
||||
}
|
||||
exprs = ecx.expr_tuple(new_decl_span, exprs2);
|
||||
}
|
||||
_ => {
|
||||
// Interestingly, even the `-> ArbitraryType` case
|
||||
// ends up getting matched and handled correctly above,
|
||||
// so we don't have to handle any other case for now.
|
||||
panic!("Unsupported return type: {:?}", d_ret_ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||
} else {
|
||||
unreachable!("Unsupported mode: {:?}", x.mode);
|
||||
}
|
||||
|
||||
let ret: P<ast::Expr>;
|
||||
match &exprs[..] {
|
||||
[] => {
|
||||
assert!(!has_ret(&d_sig.decl.output));
|
||||
// We don't have to match the return type.
|
||||
return body;
|
||||
}
|
||||
[arg] => {
|
||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
|
||||
}
|
||||
args => {
|
||||
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
|
||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
|
||||
}
|
||||
}
|
||||
assert!(has_ret(&d_sig.decl.output));
|
||||
body.stmts.push(ecx.stmt_expr(ret));
|
||||
body.stmts.push(ecx.stmt_expr(exprs));
|
||||
|
||||
body
|
||||
}
|
||||
|
@ -684,50 +741,55 @@ mod llvm_enzyme {
|
|||
match activity {
|
||||
DiffActivity::Active => {
|
||||
act_ret.push(arg.ty.clone());
|
||||
// if width =/= 1, then push [arg.ty; width] to act_ret
|
||||
}
|
||||
DiffActivity::ActiveOnly => {
|
||||
// We will add the active scalar to the return type.
|
||||
// This is handled later.
|
||||
}
|
||||
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
||||
let mut shadow_arg = arg.clone();
|
||||
// We += into the shadow in reverse mode.
|
||||
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("d{}", old_name);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg);
|
||||
for i in 0..x.width {
|
||||
let mut shadow_arg = arg.clone();
|
||||
// We += into the shadow in reverse mode.
|
||||
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("d{}_{}", old_name, i);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("b{}", old_name);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg);
|
||||
for i in 0..x.width {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("b{}_{}", old_name, i);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Const => {
|
||||
// Nothing to do here.
|
||||
|
@ -783,23 +845,48 @@ mod llvm_enzyme {
|
|||
d_decl.inputs = d_inputs.into();
|
||||
|
||||
if x.mode.is_fwd() {
|
||||
let ty = match d_decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
// We want to return std::hint::black_box(()).
|
||||
let kind = TyKind::Tup(ThinVec::new());
|
||||
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty.clone());
|
||||
assert!(matches!(x.ret_activity, DiffActivity::None));
|
||||
// this won't be used below, so any type would be fine.
|
||||
ty
|
||||
}
|
||||
};
|
||||
|
||||
if let DiffActivity::Dual = x.ret_activity {
|
||||
let ty = match d_decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
panic!("Did not expect Default ret ty: {:?}", span);
|
||||
}
|
||||
let kind = if x.width == 1 {
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
||||
} else {
|
||||
// We have to return [T; width+1], +1 for the primal return.
|
||||
let anon_const = rustc_ast::AnonConst {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
value: ecx.expr_usize(span, 1 + x.width as usize),
|
||||
};
|
||||
TyKind::Array(ty.clone(), anon_const)
|
||||
};
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
if let DiffActivity::DualOnly = x.ret_activity {
|
||||
// No need to change the return type,
|
||||
// we will just return the shadow in place
|
||||
// of the primal return.
|
||||
// we will just return the shadow in place of the primal return.
|
||||
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||
if x.width > 1 {
|
||||
let anon_const = rustc_ast::AnonConst {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
value: ecx.expr_usize(span, x.width as usize),
|
||||
};
|
||||
let kind = TyKind::Array(ty.clone(), anon_const);
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -202,6 +202,14 @@ mod autodiff {
|
|||
pub(crate) mode: String,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff_width)]
|
||||
pub(crate) struct AutoDiffInvalidWidth {
|
||||
#[primary_span]
|
||||
pub(crate) span: Span,
|
||||
pub(crate) width: u128,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff)]
|
||||
pub(crate) struct AutoDiffInvalidApplication {
|
||||
|
|
|
@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
|
|||
}
|
||||
// We handle this below
|
||||
config::AutoDiff::PrintModAfter => {}
|
||||
// We handle this below
|
||||
config::AutoDiff::PrintModFinal => {}
|
||||
// This is required and already checked
|
||||
config::AutoDiff::Enable => {}
|
||||
}
|
||||
|
@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
|
|||
}
|
||||
|
||||
if cfg!(llvm_enzyme) && enable_ad {
|
||||
// This is the post-autodiff IR, mainly used for testing and educational purposes.
|
||||
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
|
||||
let opt_stage = llvm::OptStage::FatLTO;
|
||||
let stage = write::AutodiffStage::PostAD;
|
||||
unsafe {
|
||||
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
||||
}
|
||||
|
||||
// This is the final IR, so people should be able to inspect the optimized autodiff output.
|
||||
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
||||
// This is the final IR, so people should be able to inspect the optimized autodiff output,
|
||||
// for manual inspection.
|
||||
if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,10 @@ use std::ptr;
|
|||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
|
||||
use rustc_codegen_ssa::ModuleCodegen;
|
||||
use rustc_codegen_ssa::back::write::ModuleConfig;
|
||||
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
|
||||
use rustc_errors::FatalError;
|
||||
use rustc_middle::bug;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::back::write::llvm_err;
|
||||
|
@ -18,21 +20,42 @@ use crate::value::Value;
|
|||
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
|
||||
|
||||
fn get_params(fnc: &Value) -> Vec<&Value> {
|
||||
let param_num = llvm::LLVMCountParams(fnc) as usize;
|
||||
let mut fnc_args: Vec<&Value> = vec![];
|
||||
fnc_args.reserve(param_num);
|
||||
unsafe {
|
||||
let param_num = llvm::LLVMCountParams(fnc) as usize;
|
||||
let mut fnc_args: Vec<&Value> = vec![];
|
||||
fnc_args.reserve(param_num);
|
||||
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
|
||||
fnc_args.set_len(param_num);
|
||||
fnc_args
|
||||
}
|
||||
fnc_args
|
||||
}
|
||||
|
||||
fn has_sret(fnc: &Value) -> bool {
|
||||
let num_args = llvm::LLVMCountParams(fnc) as usize;
|
||||
if num_args == 0 {
|
||||
false
|
||||
} else {
|
||||
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
|
||||
}
|
||||
}
|
||||
|
||||
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
|
||||
// original inputs, as well as metadata and the additional shadow arguments.
|
||||
// This function matches the arguments from the outer function to the inner enzyme call.
|
||||
//
|
||||
// This function also considers that Rust level arguments not always match the llvm-ir level
|
||||
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
|
||||
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
|
||||
// need to match those.
|
||||
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
|
||||
// using iterators and peek()?
|
||||
fn match_args_from_caller_to_enzyme<'ll>(
|
||||
cx: &SimpleCx<'ll>,
|
||||
width: u32,
|
||||
args: &mut Vec<&'ll llvm::Value>,
|
||||
inputs: &[DiffActivity],
|
||||
outer_args: &[&'ll llvm::Value],
|
||||
has_sret: bool,
|
||||
) {
|
||||
debug!("matching autodiff arguments");
|
||||
// We now handle the issue that Rust level arguments not always match the llvm-ir level
|
||||
|
@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
|||
let mut outer_pos: usize = 0;
|
||||
let mut activity_pos = 0;
|
||||
|
||||
if has_sret {
|
||||
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
|
||||
// inner function will still return something. We increase our outer_pos by one,
|
||||
// and once we're done with all other args we will take the return of the inner call and
|
||||
// update the sret pointer with it
|
||||
outer_pos = 1;
|
||||
}
|
||||
|
||||
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
|
||||
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
|
||||
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
|
||||
|
@ -92,23 +123,20 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
|||
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
|
||||
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
|
||||
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
|
||||
assert!(unsafe {
|
||||
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
|
||||
});
|
||||
let next_outer_arg2 = outer_args[outer_pos + 2];
|
||||
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
|
||||
assert!(unsafe {
|
||||
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
|
||||
});
|
||||
let next_outer_arg3 = outer_args[outer_pos + 3];
|
||||
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
|
||||
assert!(unsafe {
|
||||
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
|
||||
});
|
||||
args.push(next_outer_arg2);
|
||||
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
|
||||
|
||||
for i in 0..(width as usize) {
|
||||
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
|
||||
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
|
||||
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
|
||||
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
|
||||
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
|
||||
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
|
||||
args.push(next_outer_arg2);
|
||||
}
|
||||
args.push(cx.get_metadata_value(enzyme_const));
|
||||
args.push(next_outer_arg);
|
||||
outer_pos += 4;
|
||||
outer_pos += 2 + 2 * width as usize;
|
||||
activity_pos += 2;
|
||||
} else {
|
||||
// A duplicated pointer will have the following two outer_fn arguments:
|
||||
|
@ -116,15 +144,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
|||
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
||||
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
|
||||
{
|
||||
assert!(
|
||||
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
|
||||
== llvm::TypeKind::Pointer
|
||||
);
|
||||
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
|
||||
}
|
||||
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
|
||||
args.push(next_outer_arg);
|
||||
outer_pos += 2;
|
||||
activity_pos += 1;
|
||||
|
||||
// Now, if width > 1, we need to account for that
|
||||
for _ in 1..width {
|
||||
let next_outer_arg = outer_args[outer_pos];
|
||||
args.push(next_outer_arg);
|
||||
outer_pos += 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We do not differentiate with resprect to this argument.
|
||||
|
@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
|
|||
}
|
||||
}
|
||||
|
||||
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
|
||||
// arguments. We do however need to declare them with their correct return type.
|
||||
// We already figured the correct return type out in our frontend, when generating the outer_fn,
|
||||
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
|
||||
// Beyond sret, this article describes our challenges nicely:
|
||||
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
|
||||
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
|
||||
fn compute_enzyme_fn_ty<'ll>(
|
||||
cx: &SimpleCx<'ll>,
|
||||
attrs: &AutoDiffAttrs,
|
||||
fn_to_diff: &'ll Value,
|
||||
outer_fn: &'ll Value,
|
||||
) -> &'ll llvm::Type {
|
||||
let fn_ty = cx.get_type_of_global(outer_fn);
|
||||
let mut ret_ty = cx.get_return_type(fn_ty);
|
||||
|
||||
let has_sret = has_sret(outer_fn);
|
||||
|
||||
if has_sret {
|
||||
// Now we don't just forward the return type, so we have to figure it out based on the
|
||||
// primal return type, in combination with the autodiff settings.
|
||||
let fn_ty = cx.get_type_of_global(fn_to_diff);
|
||||
let inner_ret_ty = cx.get_return_type(fn_ty);
|
||||
|
||||
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
|
||||
if inner_ret_ty == void_ty {
|
||||
// This indicates that even the inner function has an sret.
|
||||
// Right now I only look for an sret in the outer function.
|
||||
// This *probably* needs some extra handling, but I never ran
|
||||
// into such a case. So I'll wait for user reports to have a test case.
|
||||
bug!("sret in inner function");
|
||||
}
|
||||
|
||||
if attrs.width == 1 {
|
||||
todo!("Handle sret for scalar ad");
|
||||
} else {
|
||||
// First we check if we also have to deal with the primal return.
|
||||
match attrs.mode {
|
||||
DiffMode::Forward => match attrs.ret_activity {
|
||||
DiffActivity::Dual => {
|
||||
let arr_ty =
|
||||
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
|
||||
ret_ty = arr_ty;
|
||||
}
|
||||
DiffActivity::DualOnly => {
|
||||
let arr_ty =
|
||||
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
|
||||
ret_ty = arr_ty;
|
||||
}
|
||||
DiffActivity::Const => {
|
||||
todo!("Not sure, do we need to do something here?");
|
||||
}
|
||||
_ => {
|
||||
bug!("unreachable");
|
||||
}
|
||||
},
|
||||
DiffMode::Reverse => {
|
||||
todo!("Handle sret for reverse mode");
|
||||
}
|
||||
_ => {
|
||||
bug!("unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LLVM can figure out the input types on it's own, so we take a shortcut here.
|
||||
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
|
||||
}
|
||||
|
||||
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
|
||||
/// function with expected naming and calling conventions[^1] which will be
|
||||
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
|
||||
|
@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
|
|||
// }
|
||||
// ```
|
||||
unsafe {
|
||||
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
|
||||
// arguments. We do however need to declare them with their correct return type.
|
||||
// We already figured the correct return type out in our frontend, when generating the outer_fn,
|
||||
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
|
||||
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
|
||||
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
|
||||
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
|
||||
|
||||
// LLVM can figure out the input types on it's own, so we take a shortcut here.
|
||||
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
|
||||
|
||||
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
|
||||
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
|
||||
// think a bit more about what should go here.
|
||||
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
|
||||
let ad_fn = declare_simple_fn(
|
||||
|
@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
|
|||
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
|
||||
args.push(cx.get_metadata_value(enzyme_primal_ret));
|
||||
}
|
||||
if attrs.width > 1 {
|
||||
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
|
||||
args.push(cx.get_metadata_value(enzyme_width));
|
||||
args.push(cx.get_const_i64(attrs.width as u64));
|
||||
}
|
||||
|
||||
let has_sret = has_sret(outer_fn);
|
||||
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
|
||||
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
|
||||
match_args_from_caller_to_enzyme(
|
||||
&cx,
|
||||
attrs.width,
|
||||
&mut args,
|
||||
&attrs.input_activity,
|
||||
&outer_args,
|
||||
has_sret,
|
||||
);
|
||||
|
||||
let call = builder.call(enzyme_ty, ad_fn, &args, None);
|
||||
|
||||
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
|
||||
// metadata attachted to it, but we just created this code oota. Given that the
|
||||
// metadata attached to it, but we just created this code oota. Given that the
|
||||
// differentiated function already has partly confusing metadata, and given that this
|
||||
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
|
||||
// dummy code which we inserted at a higher level.
|
||||
|
@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
|
|||
// Now that we copied the metadata, get rid of dummy code.
|
||||
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
|
||||
|
||||
if cx.val_ty(call) == cx.type_void() {
|
||||
if cx.val_ty(call) == cx.type_void() || has_sret {
|
||||
if has_sret {
|
||||
// This is what we already have in our outer_fn (shortened):
|
||||
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
|
||||
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
|
||||
// <Here we are, we want to add the following two lines>
|
||||
// store [4 x double] %7, ptr %0, align 8
|
||||
// ret void
|
||||
// }
|
||||
|
||||
// now store the result of the enzyme call into the sret pointer.
|
||||
let sret_ptr = outer_args[0];
|
||||
let call_ty = cx.val_ty(call);
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
|
||||
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
|
||||
}
|
||||
builder.ret_void();
|
||||
} else {
|
||||
builder.ret(call);
|
||||
|
@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
|
|||
if !diff_items.is_empty()
|
||||
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
|
||||
{
|
||||
let dcx = cgcx.create_dcx();
|
||||
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
|
||||
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
|
||||
}
|
||||
|
||||
// Before dumping the module, we want all the TypeTrees to become part of the module.
|
||||
|
|
|
@ -430,7 +430,7 @@ impl<'ll> CodegenCx<'ll, '_> {
|
|||
let val_llty = self.val_ty(v);
|
||||
|
||||
let g = self.get_static_inner(def_id, val_llty);
|
||||
let llty = llvm::LLVMGlobalGetValueType(g);
|
||||
let llty = self.get_type_of_global(g);
|
||||
|
||||
let g = if val_llty == llty {
|
||||
g
|
||||
|
|
|
@ -8,6 +8,7 @@ use std::str;
|
|||
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
|
||||
use rustc_codegen_ssa::back::versioned_llvm_target;
|
||||
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
use rustc_codegen_ssa::errors as ssa_errors;
|
||||
use rustc_codegen_ssa::traits::*;
|
||||
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
|
||||
|
@ -38,7 +39,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
|
|||
use crate::llvm::Metadata;
|
||||
use crate::type_::Type;
|
||||
use crate::value::Value;
|
||||
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
|
||||
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
|
||||
|
||||
/// `TyCtxt` (and related cache datastructures) can't be move between threads.
|
||||
/// However, there are various cx related functions which we want to be available to the builder and
|
||||
|
@ -643,7 +644,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
|
|||
llvm::set_section(g, c"llvm.metadata");
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ll> SimpleCx<'ll> {
|
||||
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
|
||||
assert_eq!(self.type_kind(ty), TypeKind::Function);
|
||||
unsafe { llvm::LLVMGetReturnType(ty) }
|
||||
}
|
||||
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
|
||||
unsafe { llvm::LLVMGlobalGetValueType(val) }
|
||||
}
|
||||
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
|
||||
common::val_ty(v)
|
||||
}
|
||||
}
|
||||
impl<'ll> SimpleCx<'ll> {
|
||||
pub(crate) fn new(
|
||||
llmod: &'ll llvm::Module,
|
||||
|
@ -660,6 +672,13 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
|
|||
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
|
||||
}
|
||||
|
||||
// FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
|
||||
// onto a trait that is also implemented for GenericCx.
|
||||
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
|
||||
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
|
||||
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
|
||||
}
|
||||
|
||||
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
|
||||
let name = SmallCStr::new(name);
|
||||
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
use libc::{c_char, c_uint};
|
||||
|
||||
use super::MetadataKindId;
|
||||
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
|
||||
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
|
||||
use crate::llvm::Bool;
|
||||
|
||||
#[link(name = "llvm-wrapper", kind = "static")]
|
||||
|
@ -17,6 +17,8 @@ unsafe extern "C" {
|
|||
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
|
||||
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
|
||||
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
|
||||
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
|
||||
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
|
|
|
@ -1180,7 +1180,7 @@ unsafe extern "C" {
|
|||
|
||||
// Operations on parameters
|
||||
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
|
||||
pub(crate) fn LLVMCountParams(Fn: &Value) -> c_uint;
|
||||
pub(crate) safe fn LLVMCountParams(Fn: &Value) -> c_uint;
|
||||
pub(crate) fn LLVMGetParam(Fn: &Value, Index: c_uint) -> &Value;
|
||||
|
||||
// Operations on basic blocks
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::str::FromStr;
|
|||
|
||||
use rustc_abi::ExternAbi;
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_ast::{MetaItem, MetaItemInner, attr};
|
||||
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
|
||||
use rustc_attr_parsing::ReprAttr::ReprAlign;
|
||||
use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
|
@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
return Some(AutoDiffAttrs::source());
|
||||
}
|
||||
|
||||
let [mode, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
|
||||
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
|
||||
};
|
||||
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
|
||||
p1.segments.first().unwrap().ident
|
||||
|
@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
}
|
||||
};
|
||||
|
||||
let width: u32 = match width_meta {
|
||||
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
|
||||
let w = p1.segments.first().unwrap().ident;
|
||||
match w.as_str().parse() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(w.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
}
|
||||
MetaItemInner::Lit(lit) => {
|
||||
if let LitKind::Int(val, _) = lit.kind {
|
||||
match val.get().try_into() {
|
||||
Ok(val) => val,
|
||||
Err(_) => {
|
||||
span_bug!(lit.span, "rustc_autodiff width should fit u32");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
span_bug!(lit.span, "rustc_autodiff width should be an integer");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// First read the ret symbol from the attribute
|
||||
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
|
||||
p1.segments.first().unwrap().ident
|
||||
|
@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
|||
}
|
||||
}
|
||||
|
||||
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
|
||||
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
|
||||
}
|
||||
|
||||
pub(crate) fn provide(providers: &mut Providers) {
|
||||
|
|
|
@ -384,6 +384,12 @@ static inline void AddAttributes(T *t, unsigned Index, LLVMAttributeRef *Attrs,
|
|||
t->setAttributes(PALNew);
|
||||
}
|
||||
|
||||
extern "C" bool LLVMRustHasAttributeAtIndex(LLVMValueRef Fn, unsigned Index,
|
||||
LLVMRustAttributeKind RustAttr) {
|
||||
Function *F = unwrap<Function>(Fn);
|
||||
return F->hasParamAttribute(Index, fromRust(RustAttr));
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
|
||||
LLVMAttributeRef *Attrs,
|
||||
size_t AttrsLen) {
|
||||
|
@ -636,6 +642,10 @@ static InlineAsm::AsmDialect fromRust(LLVMRustAsmDialect Dialect) {
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" uint64_t LLVMRustGetArrayNumElements(LLVMTypeRef Ty) {
|
||||
return unwrap(Ty)->getArrayNumElements();
|
||||
}
|
||||
|
||||
extern "C" LLVMValueRef
|
||||
LLVMRustInlineAsm(LLVMTypeRef Ty, char *AsmString, size_t AsmStringLen,
|
||||
char *Constraints, size_t ConstraintsLen,
|
||||
|
|
|
@ -237,10 +237,12 @@ pub enum AutoDiff {
|
|||
PrintPerf,
|
||||
/// Print intermediate IR generation steps
|
||||
PrintSteps,
|
||||
/// Print the whole module, before running opts.
|
||||
/// Print the module, before running autodiff.
|
||||
PrintModBefore,
|
||||
/// Print the module after Enzyme differentiated everything.
|
||||
/// Print the module after running autodiff.
|
||||
PrintModAfter,
|
||||
/// Print the module after running autodiff and optimizations.
|
||||
PrintModFinal,
|
||||
|
||||
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
|
||||
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
|
||||
|
|
|
@ -711,7 +711,7 @@ mod desc {
|
|||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||
pub(crate) const parse_list_with_polarity: &str =
|
||||
"a comma-separated list of strings, with elements beginning with + or -";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||
pub(crate) const parse_number: &str = "a number";
|
||||
|
@ -1359,6 +1359,7 @@ pub mod parse {
|
|||
"PrintSteps" => AutoDiff::PrintSteps,
|
||||
"PrintModBefore" => AutoDiff::PrintModBefore,
|
||||
"PrintModAfter" => AutoDiff::PrintModAfter,
|
||||
"PrintModFinal" => AutoDiff::PrintModFinal,
|
||||
"LooseTypes" => AutoDiff::LooseTypes,
|
||||
"Inline" => AutoDiff::Inline,
|
||||
_ => {
|
||||
|
@ -2093,6 +2094,7 @@ options! {
|
|||
`=PrintSteps`
|
||||
`=PrintModBefore`
|
||||
`=PrintModAfter`
|
||||
`=PrintModFinal`
|
||||
`=LooseTypes`
|
||||
`=Inline`
|
||||
Multiple options can be combined with commas."),
|
||||
|
|
|
@ -11,7 +11,7 @@ fn square(x: &f64) -> f64 {
|
|||
x * x
|
||||
}
|
||||
|
||||
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
|
||||
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
|
||||
// CHECK-NEXT:invertstart:
|
||||
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
|
||||
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
|
||||
|
@ -22,7 +22,7 @@ fn square(x: &f64) -> f64 {
|
|||
// CHECK-NEXT:}
|
||||
|
||||
fn main() {
|
||||
let x = 3.0;
|
||||
let x = std::hint::black_box(3.0);
|
||||
let output = square(&x);
|
||||
assert_eq!(9.0, output);
|
||||
|
||||
|
|
116
tests/codegen/autodiffv.rs
Normal file
116
tests/codegen/autodiffv.rs
Normal file
|
@ -0,0 +1,116 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
//
|
||||
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
|
||||
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
|
||||
// through LLVM's O3 pipeline, which will remove most of the noise.
|
||||
// However, our integration test could also be affected by changes in how rustc lowers MIR into
|
||||
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
|
||||
// reduce this test to only match the first lines and the ret instructions.
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff;
|
||||
|
||||
#[autodiff(d_square3, Forward, Dual, DualOnly)]
|
||||
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
|
||||
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
|
||||
#[no_mangle]
|
||||
fn square(x: &f32) -> f32 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// d_sqaure2
|
||||
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
|
||||
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
|
||||
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
|
||||
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
|
||||
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
|
||||
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
|
||||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
|
||||
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
|
||||
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
|
||||
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
|
||||
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
|
||||
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
|
||||
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
|
||||
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
|
||||
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: ret [4 x float] %19
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// d_square3, the extra float is the original return value (x * x)
|
||||
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
|
||||
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
|
||||
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
|
||||
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
|
||||
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
|
||||
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
|
||||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
|
||||
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
|
||||
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
|
||||
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
|
||||
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
|
||||
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
|
||||
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
|
||||
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
|
||||
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
|
||||
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
|
||||
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
|
||||
// CHECK-NEXT: ret { float, [4 x float] } %21
|
||||
// CHECK-NEXT: }
|
||||
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
let output = square(&x);
|
||||
dbg!(&output);
|
||||
assert_eq!(9.0, output);
|
||||
dbg!(square(&x));
|
||||
|
||||
let mut df_dx1 = 1.0;
|
||||
let mut df_dx2 = 2.0;
|
||||
let mut df_dx3 = 3.0;
|
||||
let mut df_dx4 = 0.0;
|
||||
let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
|
||||
dbg!(o1, o2, o3, o4);
|
||||
let [output2, o1, o2, o3, o4] =
|
||||
d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
|
||||
dbg!(o1, o2, o3, o4);
|
||||
assert_eq!(output, output2);
|
||||
assert!((6.0 - o1).abs() < 1e-10);
|
||||
assert!((12.0 - o2).abs() < 1e-10);
|
||||
assert!((18.0 - o3).abs() < 1e-10);
|
||||
assert!((0.0 - o4).abs() < 1e-10);
|
||||
assert_eq!(1.0, df_dx1);
|
||||
assert_eq!(2.0, df_dx2);
|
||||
assert_eq!(3.0, df_dx3);
|
||||
assert_eq!(0.0, df_dx4);
|
||||
assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
|
||||
assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
|
||||
assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
|
||||
assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
|
||||
}
|
|
@ -25,27 +25,31 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
|||
|
||||
// We want to be sure that the same function can be differentiated in different ways
|
||||
|
||||
|
||||
// Make sure, that we add the None for the default return.
|
||||
|
||||
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, Dual, Const, Dual,)]
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
|
||||
#[inline(never)]
|
||||
pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
|
||||
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f1(x, y));
|
||||
::core::hint::black_box((bx,));
|
||||
::core::hint::black_box((f1(x, y), f64::default()))
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<(f64, f64)>::default())
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, Dual, Const, Const,)]
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f2(x, y));
|
||||
::core::hint::black_box((bx,));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f2(x, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
|
@ -53,20 +57,20 @@ pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
|||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, Dual, Const, Const,)]
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f3(x, y));
|
||||
::core::hint::black_box((bx,));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f3(x, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f4() {}
|
||||
#[rustc_autodiff(Forward, None)]
|
||||
#[rustc_autodiff(Forward, 1, None)]
|
||||
#[inline(never)]
|
||||
pub fn df4() {
|
||||
pub fn df4() -> () {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f4());
|
||||
::core::hint::black_box(());
|
||||
|
@ -76,28 +80,82 @@ pub fn df4() {
|
|||
pub fn f5(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, Const, Dual, Const,)]
|
||||
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
|
||||
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((by,));
|
||||
::core::hint::black_box((by_0,));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
}
|
||||
#[rustc_autodiff(Forward, Dual, Const, Const,)]
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
|
||||
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((bx,));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
}
|
||||
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((dx, dret));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
}
|
||||
struct DoesNotImplDefault;
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f6() -> DoesNotImplDefault {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df6() -> DoesNotImplDefault {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f6());
|
||||
::core::hint::black_box(());
|
||||
::core::hint::black_box(f6())
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f7(x: f32) -> () {}
|
||||
#[rustc_autodiff(Forward, 1, Const, None)]
|
||||
#[inline(never)]
|
||||
pub fn df7(x: f32) -> () {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f7(x));
|
||||
::core::hint::black_box(());
|
||||
}
|
||||
#[no_mangle]
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Forward, 4, Dual, Dual)]
|
||||
#[inline(never)]
|
||||
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 5usize] {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
|
||||
::core::hint::black_box(<[f32; 5usize]>::default())
|
||||
}
|
||||
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 4usize] {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
|
||||
::core::hint::black_box(<[f32; 4usize]>::default())
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<f32>::default())
|
||||
}
|
||||
fn main() {}
|
||||
|
|
|
@ -36,4 +36,22 @@ pub fn f5(x: &[f64], y: f64) -> f64 {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
struct DoesNotImplDefault;
|
||||
#[autodiff(df6, Forward, Const)]
|
||||
pub fn f6() -> DoesNotImplDefault {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// Make sure, that we add the None for the default return.
|
||||
#[autodiff(df7, Forward, Const)]
|
||||
pub fn f7(x: f32) -> () {}
|
||||
|
||||
#[autodiff(f8_1, Forward, Dual, DualOnly)]
|
||||
#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
|
||||
#[autodiff(f8_3, Forward, 4, Dual, Dual)]
|
||||
#[no_mangle]
|
||||
fn f8(x: &f32) -> f32 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
@ -28,18 +28,18 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
|||
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f1(x, y));
|
||||
::core::hint::black_box((dx, dret));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f1(x, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f2() {}
|
||||
#[rustc_autodiff(Reverse, None)]
|
||||
#[rustc_autodiff(Reverse, 1, None)]
|
||||
#[inline(never)]
|
||||
pub fn df2() {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
|
@ -51,12 +51,12 @@ pub fn df2() {
|
|||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f3(x, y));
|
||||
::core::hint::black_box((dx, dret));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f3(x, y))
|
||||
}
|
||||
enum Foo { Reverse, }
|
||||
|
@ -64,7 +64,7 @@ use Foo::Reverse;
|
|||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Reverse, Const, None)]
|
||||
#[rustc_autodiff(Reverse, 1, Const, None)]
|
||||
#[inline(never)]
|
||||
pub fn df4(x: f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
|
@ -76,11 +76,11 @@ pub fn df4(x: f32) {
|
|||
pub fn f5(x: *const f32, y: &f32) {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
|
||||
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
|
||||
#[inline(never)]
|
||||
pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
|
||||
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((dx, dy));
|
||||
::core::hint::black_box((dx_0, dy_0));
|
||||
}
|
||||
fn main() {}
|
||||
|
|
|
@ -177,4 +177,11 @@ fn f21(x: f32) -> f32 {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
struct DoesNotImplDefault;
|
||||
#[autodiff(df22, Forward, Dual)]
|
||||
pub fn f22() -> DoesNotImplDefault {
|
||||
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
@ -19,32 +19,24 @@ error: expected 1 activities, but found 2
|
|||
|
|
||||
LL | #[autodiff(df3, Reverse, Duplicated, Const)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: expected 1 activities, but found 0
|
||||
--> $DIR/autodiff_illegal.rs:27:1
|
||||
|
|
||||
LL | #[autodiff(df4, Reverse)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: Dual can not be used in Reverse Mode
|
||||
--> $DIR/autodiff_illegal.rs:34:1
|
||||
|
|
||||
LL | #[autodiff(df5, Reverse, Dual)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: Duplicated can not be used in Forward Mode
|
||||
--> $DIR/autodiff_illegal.rs:41:1
|
||||
|
|
||||
LL | #[autodiff(df6, Forward, Duplicated)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: Duplicated can not be used for this type
|
||||
--> $DIR/autodiff_illegal.rs:42:14
|
||||
|
@ -107,7 +99,6 @@ LL | #[autodiff(fn_exists, Reverse, Active)]
|
|||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `fn_exists` redefined here
|
||||
|
|
||||
= note: `fn_exists` must be defined only once in the value namespace of this module
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: autodiff requires at least a name and mode
|
||||
--> $DIR/autodiff_illegal.rs:95:1
|
||||
|
@ -135,42 +126,49 @@ error: invalid return activity Active in Forward Mode
|
|||
|
|
||||
LL | #[autodiff(df19, Forward, Dual, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: invalid return activity Dual in Reverse Mode
|
||||
--> $DIR/autodiff_illegal.rs:167:1
|
||||
|
|
||||
LL | #[autodiff(df20, Reverse, Active, Dual)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: invalid return activity Duplicated in Reverse Mode
|
||||
--> $DIR/autodiff_illegal.rs:174:1
|
||||
|
|
||||
LL | #[autodiff(df21, Reverse, Active, Duplicated)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type `MyFloat`
|
||||
--> $DIR/autodiff_illegal.rs:130:1
|
||||
|
|
||||
LL | #[autodiff(df15, Reverse, Active, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type `F64Trans`
|
||||
--> $DIR/autodiff_illegal.rs:154:1
|
||||
|
|
||||
LL | #[autodiff(df18, Reverse, Active, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
|
||||
|
||||
error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
|
||||
--> $DIR/autodiff_illegal.rs:181:1
|
||||
|
|
||||
LL | struct DoesNotImplDefault;
|
||||
| ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
|
||||
LL | #[autodiff(df22, Forward, Dual)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
|
||||
|
|
||||
= note: the following trait bounds were not satisfied:
|
||||
`DoesNotImplDefault: Default`
|
||||
which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
|
||||
help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
|
||||
|
|
||||
LL + #[derive(Default)]
|
||||
LL | struct DoesNotImplDefault;
|
||||
|
|
||||
= note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
error: aborting due to 22 previous errors
|
||||
error: aborting due to 23 previous errors
|
||||
|
||||
Some errors have detailed explanations: E0428, E0433, E0658.
|
||||
Some errors have detailed explanations: E0428, E0433, E0599, E0658.
|
||||
For more information about an error, try `rustc --explain E0428`.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue