Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching Enzyme supports batching, which is especially known from the ML side when training neural networks. There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights. That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations. Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size, and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of ```rs for i in 0..100 { df(x[i], y[i], 1234); } ``` You can now do ```rs for i in 0..100.step_by(4) { df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234); } ``` which will give the same results, but allows better compiler optimizations. See the testcase for details. There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days. I will also add more tests for both modes. For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU I'll also add some other docs to the dev guide and user docs in another PR. r? ghost Tracking: - https://github.com/rust-lang/rust/issues/124509 - https://github.com/rust-lang/rust/issues/135283
This commit is contained in:
commit
c6bf3a01ef
21 changed files with 728 additions and 234 deletions
|
@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
|
|||
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
||||
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
||||
|
||||
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
|
||||
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
||||
.label = not applicable here
|
||||
.label2 = not a `struct`, `enum` or `union`
|
||||
|
|
|
@ -12,12 +12,12 @@ mod llvm_enzyme {
|
|||
valid_ty_for_activity,
|
||||
};
|
||||
use rustc_ast::ptr::P;
|
||||
use rustc_ast::token::{Token, TokenKind};
|
||||
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
|
||||
use rustc_ast::tokenstream::*;
|
||||
use rustc_ast::visit::AssocCtxt::*;
|
||||
use rustc_ast::{
|
||||
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
|
||||
PatKind, TyKind,
|
||||
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
|
||||
MetaItemInner, PatKind, QSelf, TyKind,
|
||||
};
|
||||
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||
use rustc_span::{Ident, Span, Symbol, kw, sym};
|
||||
|
@ -45,6 +45,16 @@ mod llvm_enzyme {
|
|||
}
|
||||
}
|
||||
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
|
||||
if let Some(l) = x.lit() {
|
||||
match l.kind {
|
||||
ast::LitKind::Int(val, _) => {
|
||||
// get an Ident from a lit
|
||||
return rustc_span::Ident::from_str(val.get().to_string().as_str());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let segments = &x.meta_item().unwrap().path.segments;
|
||||
assert!(segments.len() == 1);
|
||||
segments[0].ident
|
||||
|
@ -54,6 +64,14 @@ mod llvm_enzyme {
|
|||
first_ident(x).name.to_string()
|
||||
}
|
||||
|
||||
fn width(x: &MetaItemInner) -> Option<u128> {
|
||||
let lit = x.lit()?;
|
||||
match lit.kind {
|
||||
ast::LitKind::Int(x, _) => Some(x.get()),
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_ast(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
meta_item: &ThinVec<MetaItemInner>,
|
||||
|
@ -65,9 +83,32 @@ mod llvm_enzyme {
|
|||
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||
return AutoDiffAttrs::error();
|
||||
};
|
||||
|
||||
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
||||
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
||||
let mut first_activity = 2;
|
||||
|
||||
let width = if let [_, _, x, ..] = &meta_item[..]
|
||||
&& let Some(x) = width(x)
|
||||
{
|
||||
first_activity = 3;
|
||||
match x.try_into() {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
||||
span: meta_item[2].span(),
|
||||
width: x,
|
||||
});
|
||||
return AutoDiffAttrs::error();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let mut activities: Vec<DiffActivity> = vec![];
|
||||
let mut errors = false;
|
||||
for x in &meta_item[2..] {
|
||||
for x in &meta_item[first_activity..] {
|
||||
let activity_str = name(&x);
|
||||
let res = DiffActivity::from_str(&activity_str);
|
||||
match res {
|
||||
|
@ -98,7 +139,20 @@ mod llvm_enzyme {
|
|||
(&DiffActivity::None, activities.as_slice())
|
||||
};
|
||||
|
||||
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
|
||||
AutoDiffAttrs {
|
||||
mode,
|
||||
width,
|
||||
ret_activity: *ret_activity,
|
||||
input_activity: input_activity.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
|
||||
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||
let val = first_ident(t);
|
||||
let t = Token::from_ast_ident(val);
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
}
|
||||
|
||||
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||
|
@ -195,27 +249,49 @@ mod llvm_enzyme {
|
|||
|
||||
// create TokenStream from vec elemtents:
|
||||
// meta_item doesn't have a .tokens field
|
||||
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||
let mut ts: Vec<TokenTree> = vec![];
|
||||
if meta_item_vec.len() < 2 {
|
||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||
// input and output args.
|
||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||
return vec![item];
|
||||
} else {
|
||||
for t in meta_item_vec.clone()[1..].iter() {
|
||||
let val = first_ident(t);
|
||||
let t = Token::from_ast_ident(val);
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
}
|
||||
}
|
||||
|
||||
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
||||
|
||||
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||
// If it is not given, we default to 1 (scalar mode).
|
||||
let start_position;
|
||||
let kind: LitKind = LitKind::Integer;
|
||||
let symbol;
|
||||
if meta_item_vec.len() >= 3
|
||||
&& let Some(width) = width(&meta_item_vec[2])
|
||||
{
|
||||
start_position = 3;
|
||||
symbol = Symbol::intern(&width.to_string());
|
||||
} else {
|
||||
start_position = 2;
|
||||
symbol = sym::integer(1);
|
||||
}
|
||||
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
|
||||
for t in meta_item_vec.clone()[start_position..].iter() {
|
||||
meta_item_inner_to_ts(t, &mut ts);
|
||||
}
|
||||
|
||||
if !has_ret {
|
||||
// We don't want users to provide a return activity if the function doesn't return anything.
|
||||
// For simplicity, we just add a dummy token to the end of the list.
|
||||
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
||||
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||
ts.push(TokenTree::Token(comma, Spacing::Alone));
|
||||
}
|
||||
// We remove the last, trailing comma.
|
||||
ts.pop();
|
||||
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||
|
||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||
|
@ -470,6 +546,8 @@ mod llvm_enzyme {
|
|||
return body;
|
||||
}
|
||||
|
||||
// Everything from here onwards just tries to fullfil the return type. Fun!
|
||||
|
||||
// having an active-only return means we'll drop the original return type.
|
||||
// So that can be treated identical to not having one in the first place.
|
||||
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
||||
|
@ -497,86 +575,65 @@ mod llvm_enzyme {
|
|||
return body;
|
||||
}
|
||||
|
||||
let mut exprs = ThinVec::<P<ast::Expr>>::new();
|
||||
if primal_ret {
|
||||
// We have both primal ret and active floats.
|
||||
// primal ret is first, by construction.
|
||||
exprs.push(primal_call);
|
||||
}
|
||||
|
||||
// Now construct default placeholder for each active float.
|
||||
// Is there something nicer than f32::default() and f64::default()?
|
||||
let mut exprs: P<ast::Expr> = primal_call.clone();
|
||||
let d_ret_ty = match d_sig.decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
panic!("Did not expect Default ret ty: {:?}", span);
|
||||
}
|
||||
};
|
||||
let mut d_ret_ty = match d_ret_ty.kind.clone() {
|
||||
TyKind::Tup(ref tys) => tys.clone(),
|
||||
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
|
||||
if let [segment] = &segments[..]
|
||||
&& segment.args.is_none()
|
||||
{
|
||||
let id = vec![segments[0].ident];
|
||||
let kind = TyKind::Path(None, ecx.path(span, id));
|
||||
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||
thin_vec![ty]
|
||||
} else {
|
||||
panic!("Expected tuple or simple path return type");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// We messed up construction of d_sig
|
||||
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
|
||||
}
|
||||
};
|
||||
|
||||
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
|
||||
assert!(d_ret_ty.len() == 2);
|
||||
// both should be identical, by construction
|
||||
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
|
||||
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
|
||||
assert!(arg == arg2);
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs.push(default_call_expr);
|
||||
} else if x.mode.is_rev() {
|
||||
if primal_ret {
|
||||
// We have extra handling above for the primal ret
|
||||
d_ret_ty = d_ret_ty[1..].to_vec().into();
|
||||
}
|
||||
|
||||
for arg in d_ret_ty.iter() {
|
||||
let arg = arg.kind.is_simple_path().unwrap();
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
if x.mode.is_fwd() {
|
||||
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
||||
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
||||
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
|
||||
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
|
||||
if x.ret_activity == DiffActivity::Const {
|
||||
// Here we call the primal function, since our dummy function has the same return
|
||||
// type due to the Const return activity.
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||
} else {
|
||||
let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
|
||||
let y =
|
||||
ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
|
||||
let default_call_expr = ecx.expr(span, y);
|
||||
let default_call_expr =
|
||||
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs.push(default_call_expr);
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
|
||||
}
|
||||
} else if x.mode.is_rev() {
|
||||
if x.width == 1 {
|
||||
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
|
||||
match d_ret_ty.kind {
|
||||
TyKind::Tup(ref args) => {
|
||||
// We have a tuple return type. We need to create a tuple of the same size
|
||||
// and fill it with default values.
|
||||
let mut exprs2 = thin_vec![exprs];
|
||||
for arg in args.iter().skip(1) {
|
||||
let arg = arg.kind.is_simple_path().unwrap();
|
||||
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||
let tmp = ecx.def_site_path(&sl);
|
||||
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||
let default_call_expr =
|
||||
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||
exprs2.push(default_call_expr);
|
||||
}
|
||||
exprs = ecx.expr_tuple(new_decl_span, exprs2);
|
||||
}
|
||||
_ => {
|
||||
// Interestingly, even the `-> ArbitraryType` case
|
||||
// ends up getting matched and handled correctly above,
|
||||
// so we don't have to handle any other case for now.
|
||||
panic!("Unsupported return type: {:?}", d_ret_ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
|
||||
} else {
|
||||
unreachable!("Unsupported mode: {:?}", x.mode);
|
||||
}
|
||||
|
||||
let ret: P<ast::Expr>;
|
||||
match &exprs[..] {
|
||||
[] => {
|
||||
assert!(!has_ret(&d_sig.decl.output));
|
||||
// We don't have to match the return type.
|
||||
return body;
|
||||
}
|
||||
[arg] => {
|
||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
|
||||
}
|
||||
args => {
|
||||
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
|
||||
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
|
||||
}
|
||||
}
|
||||
assert!(has_ret(&d_sig.decl.output));
|
||||
body.stmts.push(ecx.stmt_expr(ret));
|
||||
body.stmts.push(ecx.stmt_expr(exprs));
|
||||
|
||||
body
|
||||
}
|
||||
|
@ -684,50 +741,55 @@ mod llvm_enzyme {
|
|||
match activity {
|
||||
DiffActivity::Active => {
|
||||
act_ret.push(arg.ty.clone());
|
||||
// if width =/= 1, then push [arg.ty; width] to act_ret
|
||||
}
|
||||
DiffActivity::ActiveOnly => {
|
||||
// We will add the active scalar to the return type.
|
||||
// This is handled later.
|
||||
}
|
||||
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
||||
let mut shadow_arg = arg.clone();
|
||||
// We += into the shadow in reverse mode.
|
||||
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("d{}", old_name);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg);
|
||||
for i in 0..x.width {
|
||||
let mut shadow_arg = arg.clone();
|
||||
// We += into the shadow in reverse mode.
|
||||
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("d{}_{}", old_name, i);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("b{}", old_name);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg);
|
||||
for i in 0..x.width {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
} else {
|
||||
debug!("{:#?}", &shadow_arg.pat);
|
||||
panic!("not an ident?");
|
||||
};
|
||||
let name: String = format!("b{}_{}", old_name, i);
|
||||
new_inputs.push(name.clone());
|
||||
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||
shadow_arg.pat = P(ast::Pat {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||
span: shadow_arg.pat.span,
|
||||
tokens: shadow_arg.pat.tokens.clone(),
|
||||
});
|
||||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Const => {
|
||||
// Nothing to do here.
|
||||
|
@ -783,23 +845,48 @@ mod llvm_enzyme {
|
|||
d_decl.inputs = d_inputs.into();
|
||||
|
||||
if x.mode.is_fwd() {
|
||||
let ty = match d_decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
// We want to return std::hint::black_box(()).
|
||||
let kind = TyKind::Tup(ThinVec::new());
|
||||
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty.clone());
|
||||
assert!(matches!(x.ret_activity, DiffActivity::None));
|
||||
// this won't be used below, so any type would be fine.
|
||||
ty
|
||||
}
|
||||
};
|
||||
|
||||
if let DiffActivity::Dual = x.ret_activity {
|
||||
let ty = match d_decl.output {
|
||||
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||
FnRetTy::Default(span) => {
|
||||
panic!("Did not expect Default ret ty: {:?}", span);
|
||||
}
|
||||
let kind = if x.width == 1 {
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
||||
} else {
|
||||
// We have to return [T; width+1], +1 for the primal return.
|
||||
let anon_const = rustc_ast::AnonConst {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
value: ecx.expr_usize(span, 1 + x.width as usize),
|
||||
};
|
||||
TyKind::Array(ty.clone(), anon_const)
|
||||
};
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
if let DiffActivity::DualOnly = x.ret_activity {
|
||||
// No need to change the return type,
|
||||
// we will just return the shadow in place
|
||||
// of the primal return.
|
||||
// we will just return the shadow in place of the primal return.
|
||||
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||
if x.width > 1 {
|
||||
let anon_const = rustc_ast::AnonConst {
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
value: ecx.expr_usize(span, x.width as usize),
|
||||
};
|
||||
let kind = TyKind::Array(ty.clone(), anon_const);
|
||||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -202,6 +202,14 @@ mod autodiff {
|
|||
pub(crate) mode: String,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff_width)]
|
||||
pub(crate) struct AutoDiffInvalidWidth {
|
||||
#[primary_span]
|
||||
pub(crate) span: Span,
|
||||
pub(crate) width: u128,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff)]
|
||||
pub(crate) struct AutoDiffInvalidApplication {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue