Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk

Autodiff batching

Enzyme supports batching, which is especially known from the ML side when training neural networks.
There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights.
That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations.

Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size,
and then each Dual/Duplicated argument has not one, but N shadow arguments.  So instead of
```rs
for i in 0..100 {
   df(x[i], y[i], 1234);
}
```
You can now do
```rs
for i in 0..100.step_by(4) {
   df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234);
}
```
which will give the same results, but allows better compiler optimizations. See the testcase for details.

There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days.

I will also add more tests for both modes.

For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU
I'll also add some other docs to the dev guide and user docs in another PR.

r? ghost

Tracking:

- https://github.com/rust-lang/rust/issues/124509
- https://github.com/rust-lang/rust/issues/135283
This commit is contained in:
Stuart Cook 2025-04-05 13:18:13 +11:00 committed by GitHub
commit c6bf3a01ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 728 additions and 234 deletions

View file

@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
}
// We handle this below
config::AutoDiff::PrintModAfter => {}
// We handle this below
config::AutoDiff::PrintModFinal => {}
// This is required and already checked
config::AutoDiff::Enable => {}
}
@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
}
if cfg!(llvm_enzyme) && enable_ad {
// This is the post-autodiff IR, mainly used for testing and educational purposes.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}
// This is the final IR, so people should be able to inspect the optimized autodiff output.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
// This is the final IR, so people should be able to inspect the optimized autodiff output,
// for manual inspection.
if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
}

View file

@ -3,8 +3,10 @@ use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_errors::FatalError;
use rustc_middle::bug;
use tracing::{debug, trace};
use crate::back::write::llvm_err;
@ -18,21 +20,42 @@ use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
fn get_params(fnc: &Value) -> Vec<&Value> {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
unsafe {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
fnc_args.set_len(param_num);
fnc_args
}
fnc_args
}
fn has_sret(fnc: &Value) -> bool {
let num_args = llvm::LLVMCountParams(fnc) as usize;
if num_args == 0 {
false
} else {
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
}
}
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
// original inputs, as well as metadata and the additional shadow arguments.
// This function matches the arguments from the outer function to the inner enzyme call.
//
// This function also considers that Rust level arguments not always match the llvm-ir level
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
// need to match those.
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
cx: &SimpleCx<'ll>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
outer_args: &[&'ll llvm::Value],
has_sret: bool,
) {
debug!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
let mut outer_pos: usize = 0;
let mut activity_pos = 0;
if has_sret {
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
// inner function will still return something. We increase our outer_pos by one,
// and once we're done with all other args we will take the return of the inner call and
// update the sret pointer with it
outer_pos = 1;
}
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@ -92,23 +123,20 @@ fn match_args_from_caller_to_enzyme<'ll>(
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
});
let next_outer_arg2 = outer_args[outer_pos + 2];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
});
let next_outer_arg3 = outer_args[outer_pos + 3];
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
});
args.push(next_outer_arg2);
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
for i in 0..(width as usize) {
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
args.push(next_outer_arg2);
}
args.push(cx.get_metadata_value(enzyme_const));
args.push(next_outer_arg);
outer_pos += 4;
outer_pos += 2 + 2 * width as usize;
activity_pos += 2;
} else {
// A duplicated pointer will have the following two outer_fn arguments:
@ -116,15 +144,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
// (..., metadata! enzyme_dup, ptr, ptr, ...).
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
{
assert!(
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
== llvm::TypeKind::Pointer
);
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
}
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
args.push(next_outer_arg);
outer_pos += 2;
activity_pos += 1;
// Now, if width > 1, we need to account for that
for _ in 1..width {
let next_outer_arg = outer_args[outer_pos];
args.push(next_outer_arg);
outer_pos += 1;
}
}
} else {
// We do not differentiate with resprect to this argument.
@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
}
}
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
// Beyond sret, this article describes our challenges nicely:
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
fn compute_enzyme_fn_ty<'ll>(
cx: &SimpleCx<'ll>,
attrs: &AutoDiffAttrs,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
) -> &'ll llvm::Type {
let fn_ty = cx.get_type_of_global(outer_fn);
let mut ret_ty = cx.get_return_type(fn_ty);
let has_sret = has_sret(outer_fn);
if has_sret {
// Now we don't just forward the return type, so we have to figure it out based on the
// primal return type, in combination with the autodiff settings.
let fn_ty = cx.get_type_of_global(fn_to_diff);
let inner_ret_ty = cx.get_return_type(fn_ty);
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
if inner_ret_ty == void_ty {
// This indicates that even the inner function has an sret.
// Right now I only look for an sret in the outer function.
// This *probably* needs some extra handling, but I never ran
// into such a case. So I'll wait for user reports to have a test case.
bug!("sret in inner function");
}
if attrs.width == 1 {
todo!("Handle sret for scalar ad");
} else {
// First we check if we also have to deal with the primal return.
match attrs.mode {
DiffMode::Forward => match attrs.ret_activity {
DiffActivity::Dual => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
ret_ty = arr_ty;
}
DiffActivity::DualOnly => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
ret_ty = arr_ty;
}
DiffActivity::Const => {
todo!("Not sure, do we need to do something here?");
}
_ => {
bug!("unreachable");
}
},
DiffMode::Reverse => {
todo!("Handle sret for reverse mode");
}
_ => {
bug!("unreachable");
}
}
}
}
// LLVM can figure out the input types on it's own, so we take a shortcut here.
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
}
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
// }
// ```
unsafe {
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
// LLVM can figure out the input types on it's own, so we take a shortcut here.
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_simple_fn(
@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
if attrs.width > 1 {
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_i64(attrs.width as u64));
}
let has_sret = has_sret(outer_fn);
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
match_args_from_caller_to_enzyme(
&cx,
attrs.width,
&mut args,
&attrs.input_activity,
&outer_args,
has_sret,
);
let call = builder.call(enzyme_ty, ad_fn, &args, None);
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attachted to it, but we just created this code oota. Given that the
// metadata attached to it, but we just created this code oota. Given that the
// differentiated function already has partly confusing metadata, and given that this
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
// dummy code which we inserted at a higher level.
@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
if cx.val_ty(call) == cx.type_void() {
if cx.val_ty(call) == cx.type_void() || has_sret {
if has_sret {
// This is what we already have in our outer_fn (shortened):
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
// <Here we are, we want to add the following two lines>
// store [4 x double] %7, ptr %0, align 8
// ret void
// }
// now store the result of the enzyme call into the sret pointer.
let sret_ptr = outer_args[0];
let call_ty = cx.val_ty(call);
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
}
builder.ret_void();
} else {
builder.ret(call);
@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
if !diff_items.is_empty()
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
{
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}
// Before dumping the module, we want all the TypeTrees to become part of the module.

View file

@ -430,7 +430,7 @@ impl<'ll> CodegenCx<'ll, '_> {
let val_llty = self.val_ty(v);
let g = self.get_static_inner(def_id, val_llty);
let llty = llvm::LLVMGlobalGetValueType(g);
let llty = self.get_type_of_global(g);
let g = if val_llty == llty {
g

View file

@ -8,6 +8,7 @@ use std::str;
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::errors as ssa_errors;
use rustc_codegen_ssa::traits::*;
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@ -38,7 +39,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
use crate::llvm::Metadata;
use crate::type_::Type;
use crate::value::Value;
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
/// `TyCtxt` (and related cache datastructures) can't be move between threads.
/// However, there are various cx related functions which we want to be available to the builder and
@ -643,7 +644,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
llvm::set_section(g, c"llvm.metadata");
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
assert_eq!(self.type_kind(ty), TypeKind::Function);
unsafe { llvm::LLVMGetReturnType(ty) }
}
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
unsafe { llvm::LLVMGlobalGetValueType(val) }
}
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
common::val_ty(v)
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn new(
llmod: &'ll llvm::Module,
@ -660,6 +672,13 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
}
// FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
// onto a trait that is also implemented for GenericCx.
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
}
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }

View file

@ -4,7 +4,7 @@
use libc::{c_char, c_uint};
use super::MetadataKindId;
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
use crate::llvm::Bool;
#[link(name = "llvm-wrapper", kind = "static")]
@ -17,6 +17,8 @@ unsafe extern "C" {
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
}
unsafe extern "C" {

View file

@ -1180,7 +1180,7 @@ unsafe extern "C" {
// Operations on parameters
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
pub(crate) fn LLVMCountParams(Fn: &Value) -> c_uint;
pub(crate) safe fn LLVMCountParams(Fn: &Value) -> c_uint;
pub(crate) fn LLVMGetParam(Fn: &Value, Index: c_uint) -> &Value;
// Operations on basic blocks