add new flag to print the module post-AD, before opts
This commit is contained in:
parent
79e17bc71e
commit
89d8948835
3 changed files with 17 additions and 5 deletions
|
@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
|
||||||
}
|
}
|
||||||
// We handle this below
|
// We handle this below
|
||||||
config::AutoDiff::PrintModAfter => {}
|
config::AutoDiff::PrintModAfter => {}
|
||||||
|
// We handle this below
|
||||||
|
config::AutoDiff::PrintModFinal => {}
|
||||||
// This is required and already checked
|
// This is required and already checked
|
||||||
config::AutoDiff::Enable => {}
|
config::AutoDiff::Enable => {}
|
||||||
}
|
}
|
||||||
|
@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg!(llvm_enzyme) && enable_ad {
|
if cfg!(llvm_enzyme) && enable_ad {
|
||||||
|
// This is the post-autodiff IR, mainly used for testing and educational purposes.
|
||||||
|
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
||||||
|
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||||
|
}
|
||||||
|
|
||||||
let opt_stage = llvm::OptStage::FatLTO;
|
let opt_stage = llvm::OptStage::FatLTO;
|
||||||
let stage = write::AutodiffStage::PostAD;
|
let stage = write::AutodiffStage::PostAD;
|
||||||
unsafe {
|
unsafe {
|
||||||
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is the final IR, so people should be able to inspect the optimized autodiff output.
|
// This is the final IR, so people should be able to inspect the optimized autodiff output,
|
||||||
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
// for manual inspection.
|
||||||
|
if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
|
||||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -235,10 +235,12 @@ pub enum AutoDiff {
|
||||||
PrintPerf,
|
PrintPerf,
|
||||||
/// Print intermediate IR generation steps
|
/// Print intermediate IR generation steps
|
||||||
PrintSteps,
|
PrintSteps,
|
||||||
/// Print the whole module, before running opts.
|
/// Print the module, before running autodiff.
|
||||||
PrintModBefore,
|
PrintModBefore,
|
||||||
/// Print the module after Enzyme differentiated everything.
|
/// Print the module after running autodiff.
|
||||||
PrintModAfter,
|
PrintModAfter,
|
||||||
|
/// Print the module after running autodiff and optimizations.
|
||||||
|
PrintModFinal,
|
||||||
|
|
||||||
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
|
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
|
||||||
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
|
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
|
||||||
|
|
|
@ -707,7 +707,7 @@ mod desc {
|
||||||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||||
pub(crate) const parse_list_with_polarity: &str =
|
pub(crate) const parse_list_with_polarity: &str =
|
||||||
"a comma-separated list of strings, with elements beginning with + or -";
|
"a comma-separated list of strings, with elements beginning with + or -";
|
||||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
|
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
|
||||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||||
pub(crate) const parse_number: &str = "a number";
|
pub(crate) const parse_number: &str = "a number";
|
||||||
|
@ -1355,6 +1355,7 @@ pub mod parse {
|
||||||
"PrintSteps" => AutoDiff::PrintSteps,
|
"PrintSteps" => AutoDiff::PrintSteps,
|
||||||
"PrintModBefore" => AutoDiff::PrintModBefore,
|
"PrintModBefore" => AutoDiff::PrintModBefore,
|
||||||
"PrintModAfter" => AutoDiff::PrintModAfter,
|
"PrintModAfter" => AutoDiff::PrintModAfter,
|
||||||
|
"PrintModFinal" => AutoDiff::PrintModFinal,
|
||||||
"LooseTypes" => AutoDiff::LooseTypes,
|
"LooseTypes" => AutoDiff::LooseTypes,
|
||||||
"Inline" => AutoDiff::Inline,
|
"Inline" => AutoDiff::Inline,
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -2088,6 +2089,7 @@ options! {
|
||||||
`=PrintSteps`
|
`=PrintSteps`
|
||||||
`=PrintModBefore`
|
`=PrintModBefore`
|
||||||
`=PrintModAfter`
|
`=PrintModAfter`
|
||||||
|
`=PrintModFinal`
|
||||||
`=LooseTypes`
|
`=LooseTypes`
|
||||||
`=Inline`
|
`=Inline`
|
||||||
Multiple options can be combined with commas."),
|
Multiple options can be combined with commas."),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue