fix fwd-mode autodiff case
This commit is contained in:
parent
335151f8bb
commit
70b9ba3d6e
1 changed files with 8 additions and 3 deletions
|
@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
|
||||||
let mut activity_pos = 0;
|
let mut activity_pos = 0;
|
||||||
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
|
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
|
||||||
while activity_pos < inputs.len() {
|
while activity_pos < inputs.len() {
|
||||||
let activity = inputs[activity_pos as usize];
|
let diff_activity = inputs[activity_pos as usize];
|
||||||
// Duplicated arguments received a shadow argument, into which enzyme will write the
|
// Duplicated arguments received a shadow argument, into which enzyme will write the
|
||||||
// gradient.
|
// gradient.
|
||||||
let (activity, duplicated): (&Metadata, bool) = match activity {
|
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
|
||||||
DiffActivity::None => panic!("not a valid input activity"),
|
DiffActivity::None => panic!("not a valid input activity"),
|
||||||
DiffActivity::Const => (enzyme_const, false),
|
DiffActivity::Const => (enzyme_const, false),
|
||||||
DiffActivity::Active => (enzyme_out, false),
|
DiffActivity::Active => (enzyme_out, false),
|
||||||
|
@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
|
||||||
// A duplicated pointer will have the following two outer_fn arguments:
|
// A duplicated pointer will have the following two outer_fn arguments:
|
||||||
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
|
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
|
||||||
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
||||||
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
|
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
|
||||||
|
assert!(
|
||||||
|
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
|
||||||
args.push(next_outer_arg);
|
args.push(next_outer_arg);
|
||||||
outer_pos += 2;
|
outer_pos += 2;
|
||||||
activity_pos += 1;
|
activity_pos += 1;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue