Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
commit
5863b426b9
5 changed files with 73 additions and 2 deletions
|
@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
|
|||
pub input_activity: Vec<DiffActivity>,
|
||||
}
|
||||
|
||||
impl AutoDiffAttrs {
|
||||
pub fn has_primal_ret(&self) -> bool {
|
||||
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
|
||||
}
|
||||
}
|
||||
|
||||
impl DiffMode {
|
||||
pub fn is_rev(&self) -> bool {
|
||||
matches!(self, DiffMode::Reverse)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue