fix fwd-mode autodiff case

This commit is contained in:
Manuel Drehwald 2025-02-05 18:47:23 -05:00
parent 335151f8bb
commit 70b9ba3d6e

View file

@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
let mut activity_pos = 0;
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
while activity_pos < inputs.len() {
let activity = inputs[activity_pos as usize];
let diff_activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the
// gradient.
let (activity, duplicated): (&Metadata, bool) = match activity {
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
DiffActivity::None => panic!("not a valid input activity"),
DiffActivity::Const => (enzyme_const, false),
DiffActivity::Active => (enzyme_out, false),
@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
// A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...).
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
assert!(
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
);
}
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
args.push(next_outer_arg);
outer_pos += 2;
activity_pos += 1;