Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching Enzyme supports batching, which is especially known from the ML side when training neural networks. There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights. That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations. Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size, and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of ```rs for i in 0..100 { df(x[i], y[i], 1234); } ``` You can now do ```rs for i in 0..100.step_by(4) { df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234); } ``` which will give the same results, but allows better compiler optimizations. See the testcase for details. There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days. I will also add more tests for both modes. For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU I'll also add some other docs to the dev guide and user docs in another PR. r? ghost Tracking: - https://github.com/rust-lang/rust/issues/124509 - https://github.com/rust-lang/rust/issues/135283
This commit is contained in:
commit
c6bf3a01ef
21 changed files with 728 additions and 234 deletions
|
@ -237,10 +237,12 @@ pub enum AutoDiff {
|
|||
PrintPerf,
|
||||
/// Print intermediate IR generation steps
|
||||
PrintSteps,
|
||||
/// Print the whole module, before running opts.
|
||||
/// Print the module, before running autodiff.
|
||||
PrintModBefore,
|
||||
/// Print the module after Enzyme differentiated everything.
|
||||
/// Print the module after running autodiff.
|
||||
PrintModAfter,
|
||||
/// Print the module after running autodiff and optimizations.
|
||||
PrintModFinal,
|
||||
|
||||
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
|
||||
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
|
||||
|
|
|
@ -711,7 +711,7 @@ mod desc {
|
|||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||
pub(crate) const parse_list_with_polarity: &str =
|
||||
"a comma-separated list of strings, with elements beginning with + or -";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||
pub(crate) const parse_number: &str = "a number";
|
||||
|
@ -1359,6 +1359,7 @@ pub mod parse {
|
|||
"PrintSteps" => AutoDiff::PrintSteps,
|
||||
"PrintModBefore" => AutoDiff::PrintModBefore,
|
||||
"PrintModAfter" => AutoDiff::PrintModAfter,
|
||||
"PrintModFinal" => AutoDiff::PrintModFinal,
|
||||
"LooseTypes" => AutoDiff::LooseTypes,
|
||||
"Inline" => AutoDiff::Inline,
|
||||
_ => {
|
||||
|
@ -2093,6 +2094,7 @@ options! {
|
|||
`=PrintSteps`
|
||||
`=PrintModBefore`
|
||||
`=PrintModAfter`
|
||||
`=PrintModFinal`
|
||||
`=LooseTypes`
|
||||
`=Inline`
|
||||
Multiple options can be combined with commas."),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue