fix fwd-mode autodiff case
This commit is contained in:
parent
335151f8bb
commit
70b9ba3d6e
1 changed files with 8 additions and 3 deletions
|
@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
|
|||
let mut activity_pos = 0;
|
||||
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
|
||||
while activity_pos < inputs.len() {
|
||||
let activity = inputs[activity_pos as usize];
|
||||
let diff_activity = inputs[activity_pos as usize];
|
||||
// Duplicated arguments received a shadow argument, into which enzyme will write the
|
||||
// gradient.
|
||||
let (activity, duplicated): (&Metadata, bool) = match activity {
|
||||
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
|
||||
DiffActivity::None => panic!("not a valid input activity"),
|
||||
DiffActivity::Const => (enzyme_const, false),
|
||||
DiffActivity::Active => (enzyme_out, false),
|
||||
|
@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
|
|||
// A duplicated pointer will have the following two outer_fn arguments:
|
||||
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
|
||||
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
||||
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
|
||||
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
|
||||
assert!(
|
||||
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
|
||||
);
|
||||
}
|
||||
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
|
||||
args.push(next_outer_arg);
|
||||
outer_pos += 2;
|
||||
activity_pos += 1;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue