Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
commit
5863b426b9
5 changed files with 73 additions and 2 deletions
|
@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
|
|||
}
|
||||
|
||||
if attrs.width == 1 {
|
||||
todo!("Handle sret for scalar ad");
|
||||
// Enzyme returns a struct of style:
|
||||
// `{ original_ret(if requested), float, float, ... }`
|
||||
let mut struct_elements = vec![];
|
||||
if attrs.has_primal_ret() {
|
||||
struct_elements.push(inner_ret_ty);
|
||||
}
|
||||
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
|
||||
// and therefore part of the return struct.
|
||||
let param_tys = cx.func_params_types(fn_ty);
|
||||
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
|
||||
if matches!(act, DiffActivity::Active) {
|
||||
// Now find the float type at position i based on the fn_ty,
|
||||
// to know what (f16/f32/f64/...) to add to the struct.
|
||||
struct_elements.push(param_ty);
|
||||
}
|
||||
}
|
||||
ret_ty = cx.type_struct(&struct_elements, false);
|
||||
} else {
|
||||
// First we check if we also have to deal with the primal return.
|
||||
match attrs.mode {
|
||||
|
@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
|
|||
// now store the result of the enzyme call into the sret pointer.
|
||||
let sret_ptr = outer_args[0];
|
||||
let call_ty = cx.val_ty(call);
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
|
||||
if attrs.width == 1 {
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
|
||||
} else {
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
|
||||
}
|
||||
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
|
||||
}
|
||||
builder.ret_void();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue