1
Fork 0

Rollup merge of #139351 - EnzymeAD:autodiff-batching2, r=oli-obk

Autodiff batching2

~I will rebase it once my first PR landed.~ done.
This autodiff batch mode is more similar to scalar autodiff, since it still only takes one shadow argument.
However, that argument is supposed to be `width` times larger.

r? `@oli-obk`

Tracking:

- https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
Matthias Krüger 2025-04-17 21:53:23 +02:00 committed by GitHub
commit 87a163523f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 220 additions and 28 deletions

View file

@ -2,7 +2,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};
@ -22,23 +22,51 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};
let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}