1
Fork 0

working dupv and dupvonly for fwd mode

This commit is contained in:
Manuel Drehwald 2025-04-05 03:10:19 -04:00
parent d4f880f8ce
commit a68ae0cbc1
5 changed files with 107 additions and 28 deletions

View file

@ -799,8 +799,19 @@ mod llvm_enzyme {
d_inputs.push(shadow_arg.clone());
}
}
DiffActivity::Dual | DiffActivity::DualOnly => {
for i in 0..x.width {
DiffActivity::Dual
| DiffActivity::DualOnly
| DiffActivity::Dualv
| DiffActivity::DualvOnly => {
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
// Enzyme to not expect N arguments, but one argument (which is instead larger).
let iterations =
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
1
} else {
x.width
};
for i in 0..iterations {
let mut shadow_arg = arg.clone();
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
@ -823,7 +834,7 @@ mod llvm_enzyme {
DiffActivity::Const => {
// Nothing to do here.
}
DiffActivity::None | DiffActivity::FakeActivitySize => {
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
panic!("Should not happen");
}
}
@ -887,8 +898,8 @@ mod llvm_enzyme {
}
};
if let DiffActivity::Dual = x.ret_activity {
let kind = if x.width == 1 {
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
// Dual can only be used for f32/f64 ret.
// In that case we return now a tuple with two floats.
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
@ -903,7 +914,7 @@ mod llvm_enzyme {
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
d_decl.output = FnRetTy::Ty(ty);
}
if let DiffActivity::DualOnly = x.ret_activity {
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
// No need to change the return type,
// we will just return the shadow in place of the primal return.
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]