1
Fork 0

Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk

add sret handling for scalar autodiff

r? `@oli-obk`

Fixing one of the todo's which I left in my previous batching PR.
This one handles sret for scalar autodiff.  `sret` mostly shows up when we try to return a lot of scalar floats.
People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier.

Tracking:

- https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
Stuart Cook 2025-04-07 22:29:21 +10:00 committed by GitHub
commit 5863b426b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 73 additions and 2 deletions

View file

@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
pub input_activity: Vec<DiffActivity>,
}
impl AutoDiffAttrs {
pub fn has_primal_ret(&self) -> bool {
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
}
}
impl DiffMode {
pub fn is_rev(&self) -> bool {
matches!(self, DiffMode::Reverse)