rust/compiler
Stuart Cook c6bf3a01ef
Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching

Enzyme supports batching, which is best known from machine learning, where it is used when training neural networks.
There we normally have a training loop: in each iteration we pass in some data (e.g. an image) and a target vector. Based on how close our prediction is to the target we compute a loss, and then use backpropagation to compute the gradients and update our weights.
Processing one sample at a time is quite inefficient, so what you normally do instead is pass in a batch of 8/16/.. images and targets and compute the gradients for all of them at once, which allows better optimizations.
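To make the training-loop description concrete, here is a minimal sketch assuming a hypothetical one-parameter linear model (`prediction = w * x`) with squared-error loss and a hand-derived gradient. It does not use Enzyme, and the names `grad_single` and `grad_batch` are made up for illustration. Each weight update uses the average gradient over a batch of four samples instead of a single sample.

```rs
// Hypothetical model for illustration only: prediction = w * x,
// per-sample loss = (w * x - t)^2, so d(loss)/dw = 2 * (w * x - t) * x.
fn grad_single(w: f64, x: f64, t: f64) -> f64 {
    2.0 * (w * x - t) * x
}

// Gradient of the mean loss over a batch = average of the per-sample gradients.
fn grad_batch(w: f64, xs: &[f64], ts: &[f64]) -> f64 {
    assert_eq!(xs.len(), ts.len(), "one target per input");
    let sum: f64 = xs.iter().zip(ts).map(|(&x, &t)| grad_single(w, x, t)).sum();
    sum / xs.len() as f64
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 4.0];
    let ts = [2.0, 4.0, 6.0, 8.0]; // targets produced by the "true" weight w = 2
    let lr = 0.05;
    let mut w = 0.0;
    // Training loop: one weight update per batch of 4 samples.
    for _ in 0..100 {
        w -= lr * grad_batch(w, &xs, &ts);
    }
    println!("w = {w:.4}"); // converges towards 2.0
}
```

Averaging (rather than summing) the per-sample gradients is one common convention; it only rescales the step size.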

Enzyme supports batching in two ways. The first one (which I implemented here) just accepts a batch size N,
and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of
```rs
for i in 0..100 {
    df(x[i], y[i], 1234);
}
```
You can now do
```rs
for i in (0..100).step_by(4) {
    df(x[i+0], x[i+1], x[i+2], x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234);
}
```
This gives the same results, but allows better compiler optimizations. See the test case for details.
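The idea of one primal argument with N shadow arguments can be sketched for the forward-mode (Dual) case with a hand-written derivative. This is not the code Enzyme generates, and `square`, `d_square` and `d_square_batched` are hypothetical names. The point is only that the primal work (here `2.0 * x`) is done once and reused for all N tangents, which is where a batched call can save work compared to N separate calls.

```rs
// Hypothetical primal function: f(x) = x * x.
fn square(x: f64) -> f64 {
    x * x
}

// Hand-written forward-mode derivative with ONE shadow (tangent) argument:
// returns (f(x), f'(x) * dx), where f'(x) = 2x.
fn d_square(x: f64, dx: f64) -> (f64, f64) {
    (square(x), 2.0 * x * dx)
}

// Hand-written batched version: one primal `x`, but N shadow arguments.
// f(x) and f'(x) are computed once and shared by all N tangents.
fn d_square_batched<const N: usize>(x: f64, dx: [f64; N]) -> (f64, [f64; N]) {
    let deriv = 2.0 * x; // shared work
    (square(x), dx.map(|d| deriv * d))
}

fn main() {
    let x = 3.0;
    // Four separate calls...
    let single: Vec<f64> = [1.0, 2.0, 3.0, 4.0].iter().map(|&d| d_square(x, d).1).collect();
    // ...versus one call with batch size 4.
    let (y, batched) = d_square_batched(x, [1.0, 2.0, 3.0, 4.0]);
    assert_eq!(y, 9.0);
    assert_eq!(single, batched.to_vec());
    println!("{batched:?}"); // [6.0, 12.0, 18.0, 24.0]
}
```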

There is a second variant, where we can mark certain arguments so that, instead of passing in N shadow arguments, we pass in a single argument that Enzyme assumes is N times longer. For example, instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next few days.
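A hand-written sketch of this second variant, again for forward mode only: the shadow is a single slice `N` times longer than the primal. The lane layout (lane `k` stored contiguously at `dx[k*len..(k+1)*len]`) is an assumption for illustration and is not specified here, and `sum_sq` / `d_sum_sq_wide` are hypothetical names rather than anything Enzyme emits.

```rs
// Hypothetical primal function: f(x) = sum of x_i^2.
fn sum_sq(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

// Hand-written forward-mode derivative where the single shadow slice `dx`
// holds N tangents back to back (ASSUMED layout: lane k = dx[k*len..(k+1)*len]).
// Returns the N directional derivatives sum_i 2 * x_i * dx_k[i].
fn d_sum_sq_wide<const N: usize>(x: &[f64], dx: &[f64]) -> [f64; N] {
    assert_eq!(dx.len(), N * x.len(), "shadow must be N times longer than the primal");
    let mut out = [0.0; N];
    if x.is_empty() {
        return out; // chunks_exact(0) would panic
    }
    for (k, lane) in dx.chunks_exact(x.len()).enumerate() {
        out[k] = x.iter().zip(lane).map(|(xi, di)| 2.0 * xi * di).sum::<f64>();
    }
    out
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    // 4 tangent directions packed into one slice of length 4 * 3 = 12.
    let dx = [
        1.0, 0.0, 0.0, // lane 0
        0.0, 1.0, 0.0, // lane 1
        0.0, 0.0, 1.0, // lane 2
        1.0, 1.0, 1.0, // lane 3
    ];
    let derivs: [f64; 4] = d_sum_sq_wide(&x, &dx);
    println!("f = {}, derivatives = {derivs:?}", sum_sq(&x));
}
```

With this layout the 4 x 12 = 48 floats from the example above would simply be one `&[f32]` or `&[f64]` of length 48 instead of four separate slices.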

I will also add more tests for both modes.

For anyone who prefers a more interactive explanation, here's a video of Tim's LLVM Dev talk, where he presents his work: https://www.youtube.com/watch?v=edvaLAL5RqU
I'll also add some other docs to the dev guide and user docs in another PR.

r? ghost

Tracking:

- https://github.com/rust-lang/rust/issues/124509
- https://github.com/rust-lang/rust/issues/135283
2025-04-05 13:18:13 +11:00
rustc Revert "Use workspace lints for crates in compiler/ #138084" 2025-03-10 18:12:47 +08:00
rustc_abi BackendRepr::is_signed: comment why this may panics 2025-03-29 12:21:51 +01:00
rustc_arena Remove #![warn(unreachable_pub)] from all compiler/ crates. 2025-03-11 13:14:21 +11:00
rustc_ast Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_ast_ir Use -Wunused_crate_dependencies for compiler crates. 2025-03-20 08:59:43 +11:00
rustc_ast_lowering Tighten up assignment operator representations. 2025-04-03 10:23:03 +11:00
rustc_ast_passes Rollup merge of #139294 - beetrees:fix-f16-f128-literal-feature-gate, r=fmease 2025-04-03 07:39:08 +02:00
rustc_ast_pretty Tighten up assignment operator representations. 2025-04-03 10:23:03 +11:00
rustc_attr_data_structures add rustc_macro_edition_2021 2025-03-19 17:37:35 +01:00
rustc_attr_parsing Avoid kw::Empty when dealing with rustc_allowed_through_unstable_modules. 2025-03-25 16:48:03 +11:00
rustc_baked_icu_data Add unreachable_pub to RUSTC_LINT_FLAGS for compiler/ crates. 2025-03-11 13:14:21 +11:00
rustc_borrowck Auto merge of #139390 - matthiaskrgr:rollup-l64euwx, r=matthiaskrgr 2025-04-04 23:03:57 +00:00
rustc_builtin_macros Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_codegen_cranelift Auto merge of #139213 - bjorn3:cg_clif_test_coretests, r=jieyouxu 2025-04-04 11:59:59 +00:00
rustc_codegen_gcc Rollup merge of #138949 - madsmtm:rename-to-darwin, r=WaffleLapkin 2025-04-04 08:02:05 +02:00
rustc_codegen_llvm Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_codegen_ssa Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_const_eval Make LevelAndSource a struct 2025-04-03 09:17:55 +00:00
rustc_data_structures Invalidate all dereferences for non-local assignments 2025-04-02 19:58:35 +08:00
rustc_driver Revert "Use workspace lints for crates in compiler/ #138084" 2025-03-10 18:12:47 +08:00
rustc_driver_impl Rollup merge of #138949 - madsmtm:rename-to-darwin, r=WaffleLapkin 2025-04-04 08:02:05 +02:00
rustc_error_codes Avoid kw::Empty when dealing with rustc_allowed_through_unstable_modules. 2025-03-25 16:48:03 +11:00
rustc_error_messages Rollup merge of #138404 - bjorn3:sysroot_handling_cleanup, r=petrochenkov,jieyouxu 2025-03-13 11:28:35 +01:00
rustc_errors Split ExpectationLintId off Level 2025-04-03 09:17:55 +00:00
rustc_expand Remove NtExpr and NtLiteral. 2025-04-02 06:20:35 +11:00
rustc_feature Rollup merge of #139080 - m-ou-se:super-let-gate, r=traviscross 2025-04-03 07:39:05 +02:00
rustc_fluent_macro Remove #![warn(unreachable_pub)] from all compiler/ crates. 2025-03-11 13:14:21 +11:00
rustc_fs_util Revert "Use workspace lints for crates in compiler/ #138084" 2025-03-10 18:12:47 +08:00
rustc_graphviz Remove #![warn(unreachable_pub)] from all compiler/ crates. 2025-03-11 13:14:21 +11:00
rustc_hashes Revert "Use workspace lints for crates in compiler/ #138084" 2025-03-10 18:12:47 +08:00
rustc_hir Auto merge of #120706 - Bryanskiy:leak, r=lcnr 2025-04-04 01:35:52 +00:00
rustc_hir_analysis Auto merge of #139390 - matthiaskrgr:rollup-l64euwx, r=matthiaskrgr 2025-04-04 23:03:57 +00:00
rustc_hir_pretty Tighten up assignment operator representations. 2025-04-03 10:23:03 +11:00
rustc_hir_typeck Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_incremental Auto merge of #138629 - Zoxc:graph-anon-hashmap, r=oli-obk 2025-03-24 15:02:09 +00:00
rustc_index Use {Decodable,Encodable}_NoContext in type_ir 2025-03-15 06:34:36 +00:00
rustc_index_macros Add unreachable_pub to RUSTC_LINT_FLAGS for compiler/ crates. 2025-03-11 13:14:21 +11:00
rustc_infer Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_interface Rollup merge of #138767 - clubby789:check-cfg-bool, r=Urgau 2025-04-03 21:18:30 +02:00
rustc_lexer Revert "Rollup merge of #136355 - GuillaumeGomez:proc-macro_add_value_retrieval_methods, r=Amanieu" 2025-03-18 13:28:56 +01:00
rustc_lint Rollup merge of #138610 - oli-obk:no-sort-hir-ids, r=compiler-errors 2025-04-03 21:18:30 +02:00
rustc_lint_defs impl !PartialOrd for HirId 2025-04-03 09:22:21 +00:00
rustc_llvm Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_log Use -Wunused_crate_dependencies for compiler crates. 2025-03-20 08:59:43 +11:00
rustc_macros Move codec module back into middle 2025-03-15 06:42:48 +00:00
rustc_metadata Rollup merge of #138949 - madsmtm:rename-to-darwin, r=WaffleLapkin 2025-04-04 08:02:05 +02:00
rustc_middle Auto merge of #139390 - matthiaskrgr:rollup-l64euwx, r=matthiaskrgr 2025-04-04 23:03:57 +00:00
rustc_mir_build Rollup merge of #138610 - oli-obk:no-sort-hir-ids, r=compiler-errors 2025-04-03 21:18:30 +02:00
rustc_mir_dataflow Auto merge of #138414 - matthiaskrgr:rollup-9ablqdb, r=matthiaskrgr 2025-03-12 17:27:43 +00:00
rustc_mir_transform Auto merge of #132527 - DianQK:gvn-stmt-iter, r=oli-obk 2025-04-03 19:17:33 +00:00
rustc_monomorphize Make missing optimized MIR error more informative 2025-04-01 09:25:12 +00:00
rustc_next_trait_solver Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_parse Rollup merge of #138017 - nnethercote:tighten-assignment-op, r=spastorino 2025-04-03 21:18:28 +02:00
rustc_parse_format Slim rustc_parse_format dependencies down 2025-03-23 07:30:18 +01:00
rustc_passes impl !PartialOrd for HirId 2025-04-03 09:22:21 +00:00
rustc_pattern_analysis Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_privacy privacy: Visit types and traits in impls in type privacy lints 2025-03-25 12:40:02 +03:00
rustc_query_impl Add a dep kind for use of the anon node with zero dependencies 2025-04-02 07:35:05 +02:00
rustc_query_system Add a dep kind for use of the anon node with zero dependencies 2025-04-02 07:35:05 +02:00
rustc_resolve Rollup merge of #139184 - Urgau:crate-root-lint-levels, r=jieyouxu 2025-04-02 22:52:45 +09:00
rustc_sanitizers Encode synthetic by-move coroutine body with a different DefPathData 2025-03-30 22:53:21 +00:00
rustc_serialize Convert rustc_serialize integration tests to unit tests. 2025-03-20 08:59:50 +11:00
rustc_session Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk 2025-04-05 13:18:13 +11:00
rustc_smir Use -Wunused_crate_dependencies for compiler crates. 2025-03-20 08:59:43 +11:00
rustc_span Auto merge of #120706 - Bryanskiy:leak, r=lcnr 2025-04-04 01:35:52 +00:00
rustc_symbol_mangling Encode synthetic by-move coroutine body with a different DefPathData 2025-03-30 22:53:21 +00:00
rustc_target Auto merge of #137869 - Noratrieb:Now_I_am_become_death,_the_destroyer_of_i686-pc-windows-gnu, r=workingjubilee 2025-04-04 15:45:03 +00:00
rustc_trait_selection Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_traits Rollup merge of #138394 - lcnr:yeet-variant, r=compiler-errors 2025-03-12 10:19:32 -07:00
rustc_transmute Add #[cfg(test)] for Transition in dfa 2025-03-18 07:17:16 +00:00
rustc_ty_utils add TypingMode::Borrowck 2025-04-03 11:13:10 +02:00
rustc_type_ir Auto merge of #138785 - lcnr:typing-mode-borrowck, r=compiler-errors,oli-obk 2025-04-04 19:54:42 +00:00
rustc_type_ir_macros Fold visit into ty 2025-03-15 06:34:36 +00:00
stable_mir use try_fold instead of fold 2025-03-28 12:14:09 +00:00