handle sret for scalar autodiff
This commit is contained in:
parent
2fa8b11f09
commit
d6467d34ae
2 changed files with 28 additions and 2 deletions
|
@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
|
|||
}
|
||||
|
||||
if attrs.width == 1 {
|
||||
todo!("Handle sret for scalar ad");
|
||||
// Enzyme returns a struct of style:
|
||||
// `{ original_ret(if requested), float, float, ... }`
|
||||
let mut struct_elements = vec![];
|
||||
if attrs.has_primal_ret() {
|
||||
struct_elements.push(inner_ret_ty);
|
||||
}
|
||||
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
|
||||
// and therefore part of the return struct.
|
||||
let param_tys = cx.func_params_types(fn_ty);
|
||||
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
|
||||
if matches!(act, DiffActivity::Active) {
|
||||
// Now find the float type at position i based on the fn_ty,
|
||||
// to know what (f16/f32/f64/...) to add to the struct.
|
||||
struct_elements.push(param_ty);
|
||||
}
|
||||
}
|
||||
ret_ty = cx.type_struct(&struct_elements, false);
|
||||
} else {
|
||||
// First we check if we also have to deal with the primal return.
|
||||
match attrs.mode {
|
||||
|
@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
|
|||
// now store the result of the enzyme call into the sret pointer.
|
||||
let sret_ptr = outer_args[0];
|
||||
let call_ty = cx.val_ty(call);
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
|
||||
if attrs.width == 1 {
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
|
||||
} else {
|
||||
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
|
||||
}
|
||||
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
|
||||
}
|
||||
builder.ret_void();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue