1
Fork 0

fix fwd-mode autodiff case

This commit is contained in:
Manuel Drehwald 2025-02-05 18:47:23 -05:00
parent 335151f8bb
commit 70b9ba3d6e

View file

@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
let mut activity_pos = 0; let mut activity_pos = 0;
let outer_args: Vec<&llvm::Value> = get_params(outer_fn); let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
while activity_pos < inputs.len() { while activity_pos < inputs.len() {
let activity = inputs[activity_pos as usize]; let diff_activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the // Duplicated arguments received a shadow argument, into which enzyme will write the
// gradient. // gradient.
let (activity, duplicated): (&Metadata, bool) = match activity { let (activity, duplicated): (&Metadata, bool) = match diff_activity {
DiffActivity::None => panic!("not a valid input activity"), DiffActivity::None => panic!("not a valid input activity"),
DiffActivity::Const => (enzyme_const, false), DiffActivity::Const => (enzyme_const, false),
DiffActivity::Active => (enzyme_out, false), DiffActivity::Active => (enzyme_out, false),
@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
// A duplicated pointer will have the following two outer_fn arguments: // A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...). // (..., metadata! enzyme_dup, ptr, ptr, ...).
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
assert!(
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
);
}
// In the case of Dual we make no assumption about the argument's type (unlike the pointer check for Duplicated above); e.g. f32 would be valid.
args.push(next_outer_arg); args.push(next_outer_arg);
outer_pos += 2; outer_pos += 2;
activity_pos += 1; activity_pos += 1;