
move second opt run to lto phase and cleanup code

Manuel Drehwald 2025-02-10 01:35:22 -05:00
parent 21d096184e
commit 1221cff551
7 changed files with 75 additions and 54 deletions

@@ -606,10 +606,31 @@ pub(crate) fn run_pass_manager(
     // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
     // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
-    let first_run = true;
     debug!("running llvm pm opt pipeline");
     unsafe {
-        write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
+        write::llvm_optimize(
+            cgcx,
+            dcx,
+            module,
+            config,
+            opt_level,
+            opt_stage,
+            write::AutodiffStage::DuringAD,
+        )?;
+    }
+    // FIXME(ZuseZ4): Make this more granular
+    if cfg!(llvm_enzyme) && !thin {
+        unsafe {
+            write::llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                llvm::OptStage::FatLTO,
+                write::AutodiffStage::PostAD,
+            )?;
+        }
     }
     debug!("lto done");
     Ok(())
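
For orientation, a minimal standalone sketch of the control flow this hunk introduces: the LTO-phase pass manager always runs the opt pipeline once with the autodiff stage set to DuringAD, and only for fat LTO in an Enzyme-enabled build does it run a second, full optimization pass over the already differentiated module. The enums and the llvm_optimize stub below are illustrative stand-ins, not rustc's real types.

#[derive(Debug, Clone, Copy)]
enum AutodiffStage {
    // PreAD is used by the pre-LTO `optimize` path, not in this sketch.
    #[allow(dead_code)]
    PreAD,
    DuringAD,
    PostAD,
}

#[derive(Debug, Clone, Copy)]
enum OptStage {
    FatLTO,
    ThinLTO,
}

// Stand-in for `write::llvm_optimize`: just records which pipeline would run.
fn llvm_optimize(opt_stage: OptStage, autodiff_stage: AutodiffStage) {
    println!("llvm_optimize: {:?} / {:?}", opt_stage, autodiff_stage);
}

// Mirrors the hunk: one opt run with the AD stage marker, then (fat LTO only,
// and only when rustc itself was built with Enzyme) a second, full opt run.
fn run_pass_manager(llvm_enzyme: bool, thin: bool) {
    let opt_stage = if thin { OptStage::ThinLTO } else { OptStage::FatLTO };
    llvm_optimize(opt_stage, AutodiffStage::DuringAD);
    if llvm_enzyme && !thin {
        llvm_optimize(OptStage::FatLTO, AutodiffStage::PostAD);
    }
}

fn main() {
    run_pass_manager(true, false); // Enzyme build, fat LTO: two opt runs
    run_pass_manager(false, true); // plain build, thin LTO: single opt run
}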

@@ -530,6 +530,16 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option<CString> {
     config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
 }
 
+// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling)
+// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes.
+// PostAD will run all opts, including size increasing opts.
+#[derive(Debug, Eq, PartialEq)]
+pub(crate) enum AutodiffStage {
+    PreAD,
+    DuringAD,
+    PostAD,
+}
+
 pub(crate) unsafe fn llvm_optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     dcx: DiagCtxtHandle<'_>,
@@ -537,7 +547,7 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
-    skip_size_increasing_opts: bool,
+    autodiff_stage: AutodiffStage,
 ) -> Result<(), FatalError> {
     // Enzyme:
     // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
@@ -550,13 +560,16 @@ pub(crate) unsafe fn llvm_optimize(
     let unroll_loops;
     let vectorize_slp;
     let vectorize_loop;
-    let run_enzyme = cfg!(llvm_enzyme);
+    let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
 
     // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
-    // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
+    // optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
+    // We therefore have two calls to llvm_optimize, if autodiff is used.
+    //
+    // FIXME(ZuseZ4): Before shipping on nightly,
     // we should make this more granular, or at least check that the user has at least one autodiff
     // call in their code, to justify altering the compilation pipeline.
-    if skip_size_increasing_opts && run_enzyme {
+    if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
         unroll_loops = false;
         vectorize_slp = false;
         vectorize_loop = false;
@@ -566,7 +579,7 @@ pub(crate) unsafe fn llvm_optimize(
         vectorize_slp = config.vectorize_slp;
         vectorize_loop = config.vectorize_loop;
     }
-    trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop);
+    trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme);
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
@@ -686,18 +699,14 @@ pub(crate) unsafe fn optimize(
            _ => llvm::OptStage::PreLinkNoLTO,
        };
 
-        // If we know that we will later run AD, then we disable vectorization and loop unrolling
-        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling.
+        // Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
+        // FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
+        // usages, not just if we build rustc with autodiff support.
+        let autodiff_stage =
+            if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
         return unsafe {
-            llvm_optimize(
-                cgcx,
-                dcx,
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )
+            llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
        };
    }
    Ok(())
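
The new AutodiffStage argument replaces the former skip_size_increasing_opts bool and drives two decisions inside llvm_optimize: whether the Enzyme pass runs at all (only in DuringAD) and whether size-increasing opts such as loop unrolling and vectorization stay disabled (any stage other than PostAD). A self-contained sketch of that flag computation follows; the plain bool stands in for cfg!(llvm_enzyme) and Config is a hypothetical stand-in for the relevant ModuleConfig fields.

#[derive(Debug, PartialEq, Eq)]
enum AutodiffStage {
    PreAD,
    DuringAD,
    PostAD,
}

// Hypothetical stand-in for the relevant bits of rustc's ModuleConfig.
struct Config {
    unroll_loops: bool,
    vectorize_slp: bool,
    vectorize_loop: bool,
}

/// Returns (unroll_loops, vectorize_slp, vectorize_loop, run_enzyme).
fn opt_flags(llvm_enzyme: bool, stage: AutodiffStage, config: &Config) -> (bool, bool, bool, bool) {
    // The Enzyme/autodiff pass itself only runs in the DuringAD stage.
    let run_enzyme = llvm_enzyme && stage == AutodiffStage::DuringAD;
    // Before and during AD, size-increasing opts stay off so Enzyme sees
    // compact IR; in PostAD the configured values apply again.
    if llvm_enzyme && stage != AutodiffStage::PostAD {
        (false, false, false, run_enzyme)
    } else {
        (config.unroll_loops, config.vectorize_slp, config.vectorize_loop, run_enzyme)
    }
}

fn main() {
    let config = Config { unroll_loops: true, vectorize_slp: true, vectorize_loop: true };
    assert_eq!(opt_flags(true, AutodiffStage::PreAD, &config), (false, false, false, false));
    assert_eq!(opt_flags(true, AutodiffStage::DuringAD, &config), (false, false, false, true));
    assert_eq!(opt_flags(true, AutodiffStage::PostAD, &config), (true, true, true, false));
    assert_eq!(opt_flags(false, AutodiffStage::PreAD, &config), (true, true, true, false));
}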

@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
 use rustc_errors::FatalError;
-use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
-use crate::back::write::{llvm_err, llvm_optimize};
+use crate::back::write::llvm_err;
 use crate::builder::SBuilder;
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
@@ -153,7 +152,7 @@ fn generate_enzyme_call<'ll>(
        _ => {}
    }
 
-    trace!("matching autodiff arguments");
+    debug!("matching autodiff arguments");
    // We now handle the issue that Rust level arguments not always match the llvm-ir level
    // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
    // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -222,7 +221,10 @@ fn generate_enzyme_call<'ll>(
        // A duplicated pointer will have the following two outer_fn arguments:
        // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
        // (..., metadata! enzyme_dup, ptr, ptr, ...).
-        if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
+        if matches!(
+            diff_activity,
+            DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
+        ) {
            assert!(
                llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
            );
@@ -282,7 +284,7 @@ pub(crate) fn differentiate<'ll>(
    module: &'ll ModuleCodegen<ModuleLlvm>,
    cgcx: &CodegenContext<LlvmCodegenBackend>,
    diff_items: Vec<AutoDiffItem>,
-    config: &ModuleConfig,
+    _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
    for item in &diff_items {
        trace!("{}", item);
@@ -317,29 +319,6 @@ pub(crate) fn differentiate<'ll>(
    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
 
-    if let Some(opt_level) = config.opt_level {
-        let opt_stage = match cgcx.lto {
-            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
-            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
-            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
-            _ => llvm::OptStage::PreLinkNoLTO,
-        };
-        // This is our second opt call, so now we run all opts,
-        // to make sure we get the best performance.
-        let skip_size_increasing_opts = false;
-        trace!("running Module Optimization after differentiation");
-        unsafe {
-            llvm_optimize(
-                cgcx,
-                diag_handler.handle(),
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )?
-        };
-    }
    trace!("done with differentiate()");
 
    Ok(())
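
The comments in this file describe how a single Rust-level activity can expand to several llvm-ir level arguments, e.g. a Duplicated pointer becomes a metadata marker plus the primal and shadow pointers in the __enzyme call. A simplified, string-based illustration of that expansion (the trimmed DiffActivity enum and the rendered argument strings are illustrative only; the real code emits LLVM values through the builder):

#[derive(Debug)]
enum DiffActivity {
    Const,
    Active,
    Duplicated,
    DuplicatedOnly,
}

fn enzyme_call_args(activities: &[DiffActivity]) -> Vec<String> {
    let mut args = Vec::new();
    for (i, activity) in activities.iter().enumerate() {
        match activity {
            DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
                // Per the comment above: (..., metadata! enzyme_dup, ptr, ptr, ...)
                args.push("metadata !\"enzyme_dup\"".to_string());
                args.push(format!("ptr %arg{i}"));
                args.push(format!("ptr %shadow{i}"));
            }
            _ => {
                // Other activities map onto a single argument in this sketch;
                // the real code distinguishes further activity kinds.
                args.push(format!("%arg{i}"));
            }
        }
    }
    args
}

fn main() {
    let acts = [DiffActivity::Duplicated, DiffActivity::Const, DiffActivity::Active];
    println!("__enzyme call args: {}", enzyme_call_args(&acts).join(", "));
}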

@@ -193,6 +193,10 @@ fn main() {
        cfg.define(&flag, None);
    }
 
+    if tracked_env_var_os("LLVM_ENZYME").is_some() {
+        cfg.define("ENZYME", None);
+    }
+
    if tracked_env_var_os("LLVM_RUSTLLVM").is_some() {
        cfg.define("LLVM_RUSTLLVM", None);
    }

@@ -689,7 +689,9 @@ struct LLVMRustSanitizerOptions {
 };
 
 // This symbol won't be available or used when Enzyme is not enabled
-extern "C" void registerEnzyme(llvm::PassBuilder &PB) __attribute__((weak));
+#ifdef ENZYME
+extern "C" void registerEnzyme(llvm::PassBuilder &PB);
+#endif
 
 extern "C" LLVMRustResult LLVMRustOptimize(
     LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
@ -697,8 +699,9 @@ extern "C" LLVMRustResult LLVMRustOptimize(
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR, bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops, bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
bool EmitLifetimeMarkers, bool RunEnzyme, LLVMRustSanitizerOptions *SanitizerOptions, bool EmitLifetimeMarkers, bool RunEnzyme,
const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
const char *PGOUsePath, bool InstrumentCoverage,
const char *InstrProfileOutput, const char *PGOSampleUsePath, const char *InstrProfileOutput, const char *PGOSampleUsePath,
bool DebugInfoForProfiling, void *LlvmSelfProfiler, bool DebugInfoForProfiling, void *LlvmSelfProfiler,
LLVMRustSelfProfileBeforePassCallback BeforePassCallback, LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
@ -1014,6 +1017,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
} }
// now load "-enzyme" pass: // now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) { if (RunEnzyme) {
registerEnzyme(PB); registerEnzyme(PB);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
@ -1022,6 +1026,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
return LLVMRustResult::Failure; return LLVMRustResult::Failure;
} }
} }
#endif
// Upgrade all calls to old intrinsics first. // Upgrade all calls to old intrinsics first.
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;) for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

@@ -1049,9 +1049,9 @@ pub fn rustc_cargo(
    // <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
    cargo.rustflag("-Zon-broken-pipe=kill");
 
-    // We temporarily disable linking here as part of some refactoring.
-    // This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now.
-    // In a follow-up PR, we will re-enable linking here and load the pass for them.
+    // We want to link against registerEnzyme and in the future we want to use additional
+    // functionality from Enzyme core. For that we need to link against Enzyme.
+    // FIXME(ZuseZ4): Get the LLVM version number automatically instead of hardcoding it.
    if builder.config.llvm_enzyme {
        cargo.rustflag("-l").rustflag("Enzyme-19");
    }
@@ -1234,6 +1234,9 @@ fn rustc_llvm_env(builder: &Builder<'_>, cargo: &mut Cargo, target: TargetSelect
    if builder.is_rust_llvm(target) {
        cargo.env("LLVM_RUSTLLVM", "1");
    }
+    if builder.config.llvm_enzyme {
+        cargo.env("LLVM_ENZYME", "1");
+    }
 
    let llvm::LlvmResult { llvm_config, .. } = builder.ensure(llvm::Llvm { target });
    cargo.env("LLVM_CONFIG", &llvm_config);

@@ -15,9 +15,9 @@ fn square(x: &f64) -> f64 {
 // CHECK-NEXT:invertstart:
 // CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
 // CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
-// CHECK-NEXT: %1 = load double, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
+// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
 // CHECK-NEXT: %2 = fadd fast double %1, %0
-// CHECK-NEXT: store double %2, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
+// CHECK-NEXT: store double %2, ptr %"x'", align 8
 // CHECK-NEXT: ret double %_0
 // CHECK-NEXT:}