Autodiff batching
Enzyme supports batching, which is especially known from the ML side when training neural networks.
There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights.
That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations.
Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size,
and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of
```rs
for i in 0..100 {
df(x[i], y[i], 1234);
}
```
You can now do
```rs
for i in 0..100.step_by(4) {
df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234);
}
```
which will give the same results, but allows better compiler optimizations. See the testcase for details.
There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days.
I will also add more tests for both modes.
For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU
I'll also add some other docs to the dev guide and user docs in another PR.
r? ghost
Tracking:
- https://github.com/rust-lang/rust/issues/124509
- https://github.com/rust-lang/rust/issues/135283
Expose algebraic floating point intrinsics
# Problem
A stable Rust implementation of a simple dot product is 8x slower than C++ on modern x86-64 CPUs. The root cause is an inability to let the compiler reorder floating point operations for better vectorization.
See https://github.com/calder/dot-bench for benchmarks. Measurements below were performed on a i7-10875H.
### C++: 10us ✅
With Clang 18.1.3 and `-O2 -march=haswell`:
<table>
<tr>
<th>C++</th>
<th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="cc">
float dot(float *a, float *b, size_t len) {
#pragma clang fp reassociate(on)
float sum = 0.0;
for (size_t i = 0; i < len; ++i) {
sum += a[i] * b[i];
}
return sum;
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/739573c0-380a-4d84-9fd9-141343ce7e68" />
</td>
</tr>
</table>
### Nightly Rust: 10us ✅
With rustc 1.86.0-nightly (8239a37f9) and `-C opt-level=3 -C target-feature=+avx2,+fma`:
<table>
<tr>
<th>Rust</th>
<th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="rust">
fn dot(a: &[f32], b: &[f32]) -> f32 {
let mut sum = 0.0;
for i in 0..a.len() {
sum = fadd_algebraic(sum, fmul_algebraic(a[i], b[i]));
}
sum
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/9dcf953a-2cd7-42f3-bc34-7117de4c5fb9" />
</td>
</tr>
</table>
### Stable Rust: 84us ❌
With rustc 1.84.1 (e71f9a9a9) and `-C opt-level=3 -C target-feature=+avx2,+fma`:
<table>
<tr>
<th>Rust</th>
<th>Assembly</th>
</tr>
<tr>
<td>
<pre lang="rust">
fn dot(a: &[f32], b: &[f32]) -> f32 {
let mut sum = 0.0;
for i in 0..a.len() {
sum += a[i] * b[i];
}
sum
}
</pre>
</td>
<td>
<img src="https://github.com/user-attachments/assets/936a1f7e-33e4-4ff8-a732-c3cdfe068dca" />
</td>
</tr>
</table>
# Proposed Change
Add `core::intrinsics::f*_algebraic` wrappers to `f16`, `f32`, `f64`, and `f128` gated on a new `float_algebraic` feature.
# Alternatives Considered
https://github.com/rust-lang/rust/issues/21690 has a lot of good discussion of various options for supporting fast math in Rust, but is still open a decade later because any choice that opts in more than individual operations is ultimately contrary to Rust's design principles.
In the mean time, processors have evolved and we're leaving major performance on the table by not supporting vectorization. We shouldn't make users choose between an unstable compiler and an 8x performance hit.
# References
* https://github.com/rust-lang/rust/issues/21690
* https://github.com/rust-lang/libs-team/issues/532
* https://github.com/rust-lang/rust/issues/136469
* https://github.com/calder/dot-bench
* https://www.felixcloutier.com/x86/vfmadd132ps:vfmadd213ps:vfmadd231ps
try-job: x86_64-gnu-nopt
try-job: x86_64-gnu-aux
gvn: Invalid dereferences for all non-local mutations
Fixes#132353.
This PR removes the computation value by traversing SSA locals through `for_each_assignment_mut`.
Because the `for_each_assignment_mut` traversal skips statements which have side effects, such as dereference assignments, the computation may be unsound. Instead of `for_each_assignment_mut`, we compute values by traversing in reverse postorder.
Because we compute and use the symbolic representation of values on the fly, I invalidate all old values when encountering a dereference assignment. The current approach does not prevent the optimization of a clone to a copy.
In the future, we may add an alias model, or dominance information for dereference assignments, or SSA form to help GVN.
r? cjgillot
cc `@jieyouxu` #132356
cc `@RalfJung` #133474
slice: Remove some uses of unsafe in first/last chunk methods
Remove unsafe `split_at_unchecked` and `split_at_mut_unchecked` in some slice `split_first_chunk`/`split_last_chunk` methods.
Replace those calls with the safe `split_at` and `split_at_checked` where applicable.
Add codegen tests to check for no panics when calculating the last chunk index using `checked_sub` and `split_at`.
Better viewed with whitespace disabled in diff view
---
The unchecked calls are mostly manual implementations of the safe methods, but with the safety condition negated from `mid <= len` to `len < mid`.
```rust
if self.len() < N {
None
} else {
// SAFETY: We manually verified the bounds of the split.
let (first, tail) = unsafe { self.split_at_unchecked(N) };
// Or for the last_chunk methods
let (init, last) = unsafe { self.split_at_unchecked(self.len() - N) };
```
Unsafe is still needed for the pointer array casts. Their safety comments are unmodified.
PassWrapper: adapt for llvm/llvm-project@94122d58fc77079a291a3d008914…
…006cb509d9db
We also have to remove the LLVM argument in cast-target-abi.rs for LLVM
21. I'm not really sure what the best approach here is since that test already uses revisions. We could also fork the test into a copy for LLVM 19-20 and another for LLVM 21, but what I did for now was drop the lint-abort-on-error flag to LLVM figuring that some coverage was better than none, but I'm happy to change this if that was a bad direction.
r? dianqk
````@rustbot```` label llvm-main
We also have to remove the LLVM argument in cast-target-abi.rs for LLVM
21. I'm not really sure what the best approach here is since that test
already uses revisions. We could also fork the test into a copy for LLVM
19-20 and another for LLVM 21, but what I did for now was drop the
lint-abort-on-error flag to LLVM figuring that some coverage was better
than none, but I'm happy to change this if that was a bad direction.
The above also applies for ffi-out-of-bounds-loads.rs.
r? dianqk
@rustbot label llvm-main
Remove unsafe `split_at_unchecked` and `split_at_mut_unchecked`
in some slice `split_first_chunk`/`split_last_chunk` methods.
Replace those calls with the safe `split_at` and `split_at_checked` where
applicable.
Add codegen tests to check for no panics when calculating the last
chunk index using `checked_sub` and `split_at`
Avoid wrapping constant allocations in packed structs when not necessary
This way LLVM will set the string merging flag if the alloc is a nul terminated string, reducing binary sizes.
try-job: armhf-gnu
Don't produce debug information for compiler-introduced-vars when desugaring assignments.
An assignment such as
(a, b) = (b, c);
desugars to the HIR
{ let (lhs, lhs) = (b, c); a = lhs; b = lhs; };
The repeated `lhs` leads to multiple Locals assigned to the same DILocalVariable. Rather than attempting to fix that, get rid of the debug info for these bindings that don't even exist in the program to begin with.
Fixes#138198
r? `@jieyouxu`
Lower to a memset(undef) when Rvalue::Repeat repeats uninit
Fixes https://github.com/rust-lang/rust/issues/138625.
It is technically correct to just do nothing. But if we actually do nothing, we may miss that this is de-initializing something, so instead we just lower to a single memset that writes undef. This is still superior to the memcpy loop, in both quality of code we hand to the backend and LLVM's final output.
Lower BinOp::Cmp to llvm.{s,u}cmp.* intrinsics
Lowers `mir::BinOp::Cmp` (`three_way_compare` intrinsic) to the corresponding LLVM `llvm.{s,u}cmp.i8.*` intrinsics.
These are the intrinsics mentioned in https://github.com/rust-lang/rust/pull/118310, which are now available in LLVM 19.
I couldn't find any follow-up PRs/discussions about this, please let me know if I missed something.
r? `@scottmcm`
An assignment such as
(a, b) = (b, c);
desugars to the HIR
{ let (lhs, lhs) = (b, c); a = lhs; b = lhs; };
The repeated `lhs` leads to multiple Locals assigned to the same DILocalVariable. Rather than
attempting to fix that, get rid of the debug info for these bindings that don't even exist
in the program to begin with.
Fixes#138198
Some tests expect to be compiled for a specific CPU or require certain
target features to be present (or absent). These tests work fine with
default CPUs but fail in downstream builds for RHEL and Fedora, where
we use non-default CPUs such as z13 on s390x, pwr9 on ppc64le, or
x86-64-v2/x86-64-v3 on x86_64.
Mangle rustc_std_internal_symbols functions
This reduces the risk of issues when using a staticlib or rust dylib compiled with a different rustc version in a rust program. Currently this will either (in the case of staticlib) cause a linker error due to duplicate symbol definitions, or (in the case of rust dylibs) cause rustc_std_internal_symbols functions to be silently overridden. As rust gets more commonly used inside the implementation of libraries consumed with a C interface (like Spidermonkey, Ruby YJIT (curently has to do partial linking of all rust code to hide all symbols not part of the C api), the Rusticl OpenCL implementation in mesa) this is becoming much more of an issue. With this PR the only symbols remaining with an unmangled name are rust_eh_personality (LLVM doesn't allow renaming it) and `__rust_no_alloc_shim_is_unstable`.
Helps mitigate https://github.com/rust-lang/rust/issues/104707
try-job: aarch64-gnu-debug
try-job: aarch64-apple
try-job: x86_64-apple-1
try-job: x86_64-mingw-1
try-job: i686-mingw-1
try-job: x86_64-msvc-1
try-job: i686-msvc-1
try-job: test-various
try-job: armhf-gnu
Emit function declarations for functions with `#[linkage="extern_weak"]`
Currently, when declaring an extern weak function in Rust, we use the following syntax:
```rust
unsafe extern "C" {
#[linkage = "extern_weak"]
static FOO: Option<unsafe extern "C" fn() -> ()>;
}
```
This allows runtime-checking the extern weak symbol through the Option.
When emitting LLVM-IR, the Rust compiler currently emits this static as an i8, and a pointer that is initialized with the value of the global i8 and represents the nullabilty e.g.
```
`@FOO` = extern_weak global i8
`@_rust_extern_with_linkage_FOO` = internal global ptr `@FOO`
```
This approach does not work well with CFI, where we need to attach CFI metadata to a concrete function declaration, which was pointed out in https://github.com/rust-lang/rust/issues/115199.
This change switches to emitting a proper function declaration instead of a global i8. This allows CFI to work for extern_weak functions. Example:
```
`@_rust_extern_with_linkage_FOO` = internal global ptr `@FOO`
...
declare !type !61 !type !62 !type !63 !type !64 extern_weak void `@FOO(double)` unnamed_addr #6
```
We keep initializing the Rust internal symbol with the function declaration, which preserves the correct behavior for runtime checking the Option.
r? `@rcvalle`
cc `@jakos-sec`
try-job: test-various
Currently, when declaring an extern weak function in Rust, we use the
following syntax:
```rust
unsafe extern "C" {
#[linkage = "extern_weak"]
static FOO: Option<unsafe extern "C" fn() -> ()>;
}
```
This allows runtime-checking the extern weak symbol through the Option.
When emitting LLVM-IR, the Rust compiler currently emits this static
as an i8, and a pointer that is initialized with the value of the global
i8 and represents the nullabilty e.g.
```
@FOO = extern_weak global i8
@_rust_extern_with_linkage_FOO = internal global ptr @FOO
```
This approach does not work well with CFI, where we need to attach CFI
metadata to a concrete function declaration, which was pointed out in
https://github.com/rust-lang/rust/issues/115199.
This change switches to emitting a proper function declaration instead
of a global i8. This allows CFI to work for extern_weak functions.
We keep initializing the Rust internal symbol with the function
declaration, which preserves the correct behavior for runtime checking
the Option.
Co-authored-by: Jakob Koschel <jakobkoschel@google.com>
added some new test to check for result and options opt
Apologies for the delay. Finally have some time to get back into contributing.
## Context
- Added some tests to show optimization on result and options for 64 and 128 bits
- Relevant issue https://github.com/rust-lang/rust/issues/101210
## Some newb questions from me
- [x] My local llvm IR has `nuw` in `result_nop_match_128` etc whereas [godbolt version](https://rust.godbolt.org/z/Td9zoT5zn) doesn't have. So I put optional there, but not sure if it's desirable (maybe I'm not using the compiled llvm in the repo). I ran the test with
```bash
./x test tests/codegen/try_question_mark_nop.rs
```
- [x] Unless I'm reading it wrongly, but `option_nop_match_128` and `option_nop_traits_128` look to be **not** optimized away?
Update:
Here's the test for future reference
```rust
// CHECK-LABEL: `@option_nop_match_128`
#[no_mangle]
pub fn option_nop_match_128(x: Option<i128>) -> Option<i128> {
// CHECK: start:
// CHECK-NEXT: %trunc = trunc nuw i128 %0 to i1
// CHECK-NEXT: br i1 %trunc, label %bb3, label %bb4
// CHECK: bb3:
// CHECK-NEXT: %2 = getelementptr inbounds {{(nuw )?}}i8, ptr %_0, i64 16
// CHECK-NEXT: store i128 %1, ptr %2, align 16
// CHECK: bb4:
// CHECK-NEXT: %storemerge = phi i128 [ 1, %bb3 ], [ 0, %start ]
// CHECK-NEXT: store i128 %storemerge, ptr %_0, align 16
// CHECK-NEXT: ret void
match x {
Some(x) => Some(x),
None => None,
}
}
```
r? `@scottmcm`
Allow more top-down inlining for single-BB callees
This means that things like `<usize as Step>::forward_unchecked` and `<PartialOrd for f32>::le` will inline even if
we've already done a bunch of inlining to find the calls to them.
Fixes#138136
~~Draft as it's built atop #138135, which adds a mir-opt test that's a nice demonstration of this. To see just this change, look at <48f63e3be5>~~ Rebased to be just the inlining change, as the other existing tests show it great.
Don't `alloca` just to look at a discriminant
Today we're making LLVM do a bunch of extra work when you match on trivial stuff like `Option<bool>` or `ControlFlow<u8>`.
This PR changes that so that simple types like `Option<u32>` or `Result<(), Box<Error>>` can stay as `OperandValue::ScalarPair` and we can still read the discriminant from them, rather than needing to write them into memory to have a `PlaceValue` just to get the discriminant out.
Fixes#137503
naked functions: on windows emit `.endef` without the symbol name
tracking issue: https://github.com/rust-lang/rust/issues/90957
fixes https://github.com/rust-lang/rust/issues/138320
The `.endef` directive does not take the name as an argument. Apparently the LLVM x86_64 parser does accept this, but on i686 it's rejected. In general `i686` does some special name mangling stuff, so it's good to include it in the naked function tests.
r? ````@ChrisDenton```` (because windows)
This means that things like `<usize as Step>::forward_unchecked` and `<PartialOrd for f32>::le` will inline even if we've already done a bunch of inlining to find the calls to them.
Apply dllimport in ThinLTO
This partially reverts https://github.com/rust-lang/rust/pull/103353 by properly applying `dllimport` if `-Z dylib-lto` is passed. That PR should probably fully be reverted as it looks quite sketchy. We don't know locally if the entire crate graph would be statically linked.
This should hopefully be sufficient to make ThinLTO work for rustc on Windows.
r? ``@wesleywiser``
---
Edit: This PR is changed to just generally revert https://github.com/rust-lang/rust/pull/103353.
Don't re-`assume` in `transmute`s that don't change niches
I noticed in nightly 2025-02-21 that `transmute` is emitting way more `assume`s than necessary for newtypes.
For example, the three transmutes in <https://rust.godbolt.org/z/fW1KaTc4o> emits
```rust
define noundef range(i32 1, 0) i32 `@repeatedly_transparent_transmute(i32` noundef range(i32 1, 0) %_1) unnamed_addr {
start:
%0 = sub i32 %_1, 1
%1 = icmp ule i32 %0, -2
call void `@llvm.assume(i1` %1)
%2 = sub i32 %_1, 1
%3 = icmp ule i32 %2, -2
call void `@llvm.assume(i1` %3)
%4 = sub i32 %_1, 1
%5 = icmp ule i32 %4, -2
call void `@llvm.assume(i1` %5)
%6 = sub i32 %_1, 1
%7 = icmp ule i32 %6, -2
call void `@llvm.assume(i1` %7)
%8 = sub i32 %_1, 1
%9 = icmp ule i32 %8, -2
call void `@llvm.assume(i1` %9)
%10 = sub i32 %_1, 1
%11 = icmp ule i32 %10, -2
call void `@llvm.assume(i1` %11)
ret i32 %_1
}
```
But those are all just newtypes that don't change size or niches, so none of it's needed.
After this PR it's down to just
```rust
define noundef range(i32 1, 0) i32 `@repeatedly_transparent_transmute(i32` noundef range(i32 1, 0) %_1) unnamed_addr {
start:
ret i32 %_1
}
```
because none of those `assume`s in the original actually did anything.
(Transmuting to something with a difference niche, though, still has the assumes -- the other tests continue to pass checking that.)
Don't include global asm in `mir_keys`, fix error body synthesis
r? oli-obk
Fixes#137470Fixes#137471Fixes#137472Fixes#137473
try-job: test-various
try-job: x86_64-apple-2
Break critical edges in inline asm before code generation
An inline asm terminator defines outputs along its target edges -- a
fallthrough target and labeled targets. Code generation implements this
by inserting code directly into the target blocks. This approach works
only if the target blocks don't have other predecessors.
Establish required invariant by extending existing code that breaks
critical edges before code generation.
Fixes#137867.
r? ``@bjorn3``
An inline asm terminator defines outputs along its target edges -- a
fallthrough target and labeled targets. Code generation implements this
by inserting code directly into the target blocks. This approach works
only if the target blocks don't have other predecessors.
Establish required invariant by extending existing code that breaks
critical edges before code generation.