1
Fork 0

Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk

Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle

This PR should not be merged until the rustc_codegen_llvm part is merged.
I will also alter it a little based on what get's shaved off from the cg_llvm PR,
and address some of the feedback I received in the other PR (including cleanups).

I am putting it already up to
1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and
2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost.

Re 1: My test require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either that work pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thoughts and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please?
For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook

Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: https://github.com/EnzymeAD/rust/issues/173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just get's dropped, but I don't know WHY? Any help would be super appreciated, as rustc queries look a bit voodoo to me.

Tracking:

- https://github.com/rust-lang/rust/issues/124509

r? `@jieyouxu`
This commit is contained in:
Jacob Pratt 2025-01-31 00:26:30 -05:00 committed by GitHub
commit c19c4b91f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 482 additions and 38 deletions

View file

@ -398,6 +398,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
@ -1029,6 +1030,38 @@ pub mod parse {
}
}
pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
let Some(v) = v else {
*slot = vec![];
return true;
};
let mut v: Vec<&str> = v.split(",").collect();
v.sort_unstable();
for &val in v.iter() {
let variant = match val {
"PrintTA" => AutoDiff::PrintTA,
"PrintAA" => AutoDiff::PrintAA,
"PrintPerf" => AutoDiff::PrintPerf,
"Print" => AutoDiff::Print,
"PrintModBefore" => AutoDiff::PrintModBefore,
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
"LooseTypes" => AutoDiff::LooseTypes,
"NoModOptAfter" => AutoDiff::NoModOptAfter,
"EnableFncOpt" => AutoDiff::EnableFncOpt,
"NoVecUnroll" => AutoDiff::NoVecUnroll,
"Inline" => AutoDiff::Inline,
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
return false;
}
};
slot.push(variant);
}
true
}
pub(crate) fn parse_instrument_coverage(
slot: &mut InstrumentCoverage,
v: Option<&str>,
@ -1736,6 +1769,22 @@ options! {
either `loaded` or `not-loaded`."),
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
"make cfg(version) treat the current version as incomplete (default: no)"),
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
"a list of optional autodiff flags to enable
Optional extra settings:
`=PrintTA`
`=PrintAA`
`=PrintPerf`
`=Print`
`=PrintModBefore`
`=PrintModAfterOpts`
`=PrintModAfterEnzyme`
`=LooseTypes`
`=NoModOptAfter`
`=EnableFncOpt`
`=NoVecUnroll`
`=Inline`
Multiple options can be combined with commas."),
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \