working dupv and dupvonly for fwd mode
This commit is contained in:
parent
d4f880f8ce
commit
a68ae0cbc1
5 changed files with 107 additions and 28 deletions
|
@ -799,8 +799,19 @@ mod llvm_enzyme {
|
|||
d_inputs.push(shadow_arg.clone());
|
||||
}
|
||||
}
|
||||
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||
for i in 0..x.width {
|
||||
DiffActivity::Dual
|
||||
| DiffActivity::DualOnly
|
||||
| DiffActivity::Dualv
|
||||
| DiffActivity::DualvOnly => {
|
||||
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
|
||||
// Enzyme to not expect N arguments, but one argument (which is instead larger).
|
||||
let iterations =
|
||||
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
|
||||
1
|
||||
} else {
|
||||
x.width
|
||||
};
|
||||
for i in 0..iterations {
|
||||
let mut shadow_arg = arg.clone();
|
||||
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||
ident.name
|
||||
|
@ -823,7 +834,7 @@ mod llvm_enzyme {
|
|||
DiffActivity::Const => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
DiffActivity::None | DiffActivity::FakeActivitySize => {
|
||||
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
|
||||
panic!("Should not happen");
|
||||
}
|
||||
}
|
||||
|
@ -887,8 +898,8 @@ mod llvm_enzyme {
|
|||
}
|
||||
};
|
||||
|
||||
if let DiffActivity::Dual = x.ret_activity {
|
||||
let kind = if x.width == 1 {
|
||||
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
|
||||
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
|
||||
// Dual can only be used for f32/f64 ret.
|
||||
// In that case we return now a tuple with two floats.
|
||||
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
|
||||
|
@ -903,7 +914,7 @@ mod llvm_enzyme {
|
|||
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||
d_decl.output = FnRetTy::Ty(ty);
|
||||
}
|
||||
if let DiffActivity::DualOnly = x.ret_activity {
|
||||
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
|
||||
// No need to change the return type,
|
||||
// we will just return the shadow in place of the primal return.
|
||||
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue