move second opt run to lto phase and cleanup code
This commit is contained in:
parent
21d096184e
commit
1221cff551
7 changed files with 75 additions and 54 deletions
|
@ -606,10 +606,31 @@ pub(crate) fn run_pass_manager(
|
||||||
|
|
||||||
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
|
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
|
||||||
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
|
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
|
||||||
let first_run = true;
|
|
||||||
debug!("running llvm pm opt pipeline");
|
debug!("running llvm pm opt pipeline");
|
||||||
unsafe {
|
unsafe {
|
||||||
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
|
write::llvm_optimize(
|
||||||
|
cgcx,
|
||||||
|
dcx,
|
||||||
|
module,
|
||||||
|
config,
|
||||||
|
opt_level,
|
||||||
|
opt_stage,
|
||||||
|
write::AutodiffStage::DuringAD,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
// FIXME(ZuseZ4): Make this more granular
|
||||||
|
if cfg!(llvm_enzyme) && !thin {
|
||||||
|
unsafe {
|
||||||
|
write::llvm_optimize(
|
||||||
|
cgcx,
|
||||||
|
dcx,
|
||||||
|
module,
|
||||||
|
config,
|
||||||
|
opt_level,
|
||||||
|
llvm::OptStage::FatLTO,
|
||||||
|
write::AutodiffStage::PostAD,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
debug!("lto done");
|
debug!("lto done");
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -530,6 +530,16 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option<CString> {
|
||||||
config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
|
config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling)
|
||||||
|
// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes.
|
||||||
|
// PostAD will run all opts, including size increasing opts.
|
||||||
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
|
pub(crate) enum AutodiffStage {
|
||||||
|
PreAD,
|
||||||
|
DuringAD,
|
||||||
|
PostAD,
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn llvm_optimize(
|
pub(crate) unsafe fn llvm_optimize(
|
||||||
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
||||||
dcx: DiagCtxtHandle<'_>,
|
dcx: DiagCtxtHandle<'_>,
|
||||||
|
@ -537,7 +547,7 @@ pub(crate) unsafe fn llvm_optimize(
|
||||||
config: &ModuleConfig,
|
config: &ModuleConfig,
|
||||||
opt_level: config::OptLevel,
|
opt_level: config::OptLevel,
|
||||||
opt_stage: llvm::OptStage,
|
opt_stage: llvm::OptStage,
|
||||||
skip_size_increasing_opts: bool,
|
autodiff_stage: AutodiffStage,
|
||||||
) -> Result<(), FatalError> {
|
) -> Result<(), FatalError> {
|
||||||
// Enzyme:
|
// Enzyme:
|
||||||
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
|
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
|
||||||
|
@ -550,13 +560,16 @@ pub(crate) unsafe fn llvm_optimize(
|
||||||
let unroll_loops;
|
let unroll_loops;
|
||||||
let vectorize_slp;
|
let vectorize_slp;
|
||||||
let vectorize_loop;
|
let vectorize_loop;
|
||||||
|
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
|
||||||
|
|
||||||
let run_enzyme = cfg!(llvm_enzyme);
|
|
||||||
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
|
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
|
||||||
// optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
|
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
|
||||||
|
// We therefore have two calls to llvm_optimize, if autodiff is used.
|
||||||
|
//
|
||||||
|
// FIXME(ZuseZ4): Before shipping on nightly,
|
||||||
// we should make this more granular, or at least check that the user has at least one autodiff
|
// we should make this more granular, or at least check that the user has at least one autodiff
|
||||||
// call in their code, to justify altering the compilation pipeline.
|
// call in their code, to justify altering the compilation pipeline.
|
||||||
if skip_size_increasing_opts && run_enzyme {
|
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
|
||||||
unroll_loops = false;
|
unroll_loops = false;
|
||||||
vectorize_slp = false;
|
vectorize_slp = false;
|
||||||
vectorize_loop = false;
|
vectorize_loop = false;
|
||||||
|
@ -566,7 +579,7 @@ pub(crate) unsafe fn llvm_optimize(
|
||||||
vectorize_slp = config.vectorize_slp;
|
vectorize_slp = config.vectorize_slp;
|
||||||
vectorize_loop = config.vectorize_loop;
|
vectorize_loop = config.vectorize_loop;
|
||||||
}
|
}
|
||||||
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop);
|
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme);
|
||||||
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
|
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
|
||||||
let pgo_gen_path = get_pgo_gen_path(config);
|
let pgo_gen_path = get_pgo_gen_path(config);
|
||||||
let pgo_use_path = get_pgo_use_path(config);
|
let pgo_use_path = get_pgo_use_path(config);
|
||||||
|
@ -686,18 +699,14 @@ pub(crate) unsafe fn optimize(
|
||||||
_ => llvm::OptStage::PreLinkNoLTO,
|
_ => llvm::OptStage::PreLinkNoLTO,
|
||||||
};
|
};
|
||||||
|
|
||||||
// If we know that we will later run AD, then we disable vectorization and loop unrolling
|
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
|
||||||
let skip_size_increasing_opts = cfg!(llvm_enzyme);
|
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
|
||||||
|
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
|
||||||
|
// usages, not just if we build rustc with autodiff support.
|
||||||
|
let autodiff_stage =
|
||||||
|
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
|
||||||
return unsafe {
|
return unsafe {
|
||||||
llvm_optimize(
|
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
|
||||||
cgcx,
|
|
||||||
dcx,
|
|
||||||
module,
|
|
||||||
config,
|
|
||||||
opt_level,
|
|
||||||
opt_stage,
|
|
||||||
skip_size_increasing_opts,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
|
||||||
use rustc_codegen_ssa::ModuleCodegen;
|
use rustc_codegen_ssa::ModuleCodegen;
|
||||||
use rustc_codegen_ssa::back::write::ModuleConfig;
|
use rustc_codegen_ssa::back::write::ModuleConfig;
|
||||||
use rustc_errors::FatalError;
|
use rustc_errors::FatalError;
|
||||||
use rustc_session::config::Lto;
|
|
||||||
use tracing::{debug, trace};
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
use crate::back::write::{llvm_err, llvm_optimize};
|
use crate::back::write::llvm_err;
|
||||||
use crate::builder::SBuilder;
|
use crate::builder::SBuilder;
|
||||||
use crate::context::SimpleCx;
|
use crate::context::SimpleCx;
|
||||||
use crate::declare::declare_simple_fn;
|
use crate::declare::declare_simple_fn;
|
||||||
|
@ -153,7 +152,7 @@ fn generate_enzyme_call<'ll>(
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
trace!("matching autodiff arguments");
|
debug!("matching autodiff arguments");
|
||||||
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
|
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
|
||||||
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
|
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
|
||||||
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
|
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
|
||||||
|
@ -222,7 +221,10 @@ fn generate_enzyme_call<'ll>(
|
||||||
// A duplicated pointer will have the following two outer_fn arguments:
|
// A duplicated pointer will have the following two outer_fn arguments:
|
||||||
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
|
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
|
||||||
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
// (..., metadata! enzyme_dup, ptr, ptr, ...).
|
||||||
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
|
if matches!(
|
||||||
|
diff_activity,
|
||||||
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
|
||||||
|
) {
|
||||||
assert!(
|
assert!(
|
||||||
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
|
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
|
||||||
);
|
);
|
||||||
|
@ -282,7 +284,7 @@ pub(crate) fn differentiate<'ll>(
|
||||||
module: &'ll ModuleCodegen<ModuleLlvm>,
|
module: &'ll ModuleCodegen<ModuleLlvm>,
|
||||||
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
||||||
diff_items: Vec<AutoDiffItem>,
|
diff_items: Vec<AutoDiffItem>,
|
||||||
config: &ModuleConfig,
|
_config: &ModuleConfig,
|
||||||
) -> Result<(), FatalError> {
|
) -> Result<(), FatalError> {
|
||||||
for item in &diff_items {
|
for item in &diff_items {
|
||||||
trace!("{}", item);
|
trace!("{}", item);
|
||||||
|
@ -317,29 +319,6 @@ pub(crate) fn differentiate<'ll>(
|
||||||
|
|
||||||
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
|
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
|
||||||
|
|
||||||
if let Some(opt_level) = config.opt_level {
|
|
||||||
let opt_stage = match cgcx.lto {
|
|
||||||
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
|
|
||||||
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
|
|
||||||
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
|
|
||||||
_ => llvm::OptStage::PreLinkNoLTO,
|
|
||||||
};
|
|
||||||
// This is our second opt call, so now we run all opts,
|
|
||||||
// to make sure we get the best performance.
|
|
||||||
let skip_size_increasing_opts = false;
|
|
||||||
trace!("running Module Optimization after differentiation");
|
|
||||||
unsafe {
|
|
||||||
llvm_optimize(
|
|
||||||
cgcx,
|
|
||||||
diag_handler.handle(),
|
|
||||||
module,
|
|
||||||
config,
|
|
||||||
opt_level,
|
|
||||||
opt_stage,
|
|
||||||
skip_size_increasing_opts,
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
}
|
|
||||||
trace!("done with differentiate()");
|
trace!("done with differentiate()");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -193,6 +193,10 @@ fn main() {
|
||||||
cfg.define(&flag, None);
|
cfg.define(&flag, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tracked_env_var_os("LLVM_ENZYME").is_some() {
|
||||||
|
cfg.define("ENZYME", None);
|
||||||
|
}
|
||||||
|
|
||||||
if tracked_env_var_os("LLVM_RUSTLLVM").is_some() {
|
if tracked_env_var_os("LLVM_RUSTLLVM").is_some() {
|
||||||
cfg.define("LLVM_RUSTLLVM", None);
|
cfg.define("LLVM_RUSTLLVM", None);
|
||||||
}
|
}
|
||||||
|
|
|
@ -689,7 +689,9 @@ struct LLVMRustSanitizerOptions {
|
||||||
};
|
};
|
||||||
|
|
||||||
// This symbol won't be available or used when Enzyme is not enabled
|
// This symbol won't be available or used when Enzyme is not enabled
|
||||||
extern "C" void registerEnzyme(llvm::PassBuilder &PB) __attribute__((weak));
|
#ifdef ENZYME
|
||||||
|
extern "C" void registerEnzyme(llvm::PassBuilder &PB);
|
||||||
|
#endif
|
||||||
|
|
||||||
extern "C" LLVMRustResult LLVMRustOptimize(
|
extern "C" LLVMRustResult LLVMRustOptimize(
|
||||||
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
|
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
|
||||||
|
@ -697,8 +699,9 @@ extern "C" LLVMRustResult LLVMRustOptimize(
|
||||||
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
|
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
|
||||||
bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
|
bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
|
||||||
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
|
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
|
||||||
bool EmitLifetimeMarkers, bool RunEnzyme, LLVMRustSanitizerOptions *SanitizerOptions,
|
bool EmitLifetimeMarkers, bool RunEnzyme,
|
||||||
const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage,
|
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
|
||||||
|
const char *PGOUsePath, bool InstrumentCoverage,
|
||||||
const char *InstrProfileOutput, const char *PGOSampleUsePath,
|
const char *InstrProfileOutput, const char *PGOSampleUsePath,
|
||||||
bool DebugInfoForProfiling, void *LlvmSelfProfiler,
|
bool DebugInfoForProfiling, void *LlvmSelfProfiler,
|
||||||
LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
|
LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
|
||||||
|
@ -1014,6 +1017,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
|
||||||
}
|
}
|
||||||
|
|
||||||
// now load "-enzyme" pass:
|
// now load "-enzyme" pass:
|
||||||
|
#ifdef ENZYME
|
||||||
if (RunEnzyme) {
|
if (RunEnzyme) {
|
||||||
registerEnzyme(PB);
|
registerEnzyme(PB);
|
||||||
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
|
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
|
||||||
|
@ -1022,6 +1026,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
|
||||||
return LLVMRustResult::Failure;
|
return LLVMRustResult::Failure;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Upgrade all calls to old intrinsics first.
|
// Upgrade all calls to old intrinsics first.
|
||||||
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
|
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
|
||||||
|
|
|
@ -1049,9 +1049,9 @@ pub fn rustc_cargo(
|
||||||
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
|
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
|
||||||
cargo.rustflag("-Zon-broken-pipe=kill");
|
cargo.rustflag("-Zon-broken-pipe=kill");
|
||||||
|
|
||||||
// We temporarily disable linking here as part of some refactoring.
|
// We want to link against registerEnzyme and in the future we want to use additional
|
||||||
// This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now.
|
// functionality from Enzyme core. For that we need to link against Enzyme.
|
||||||
// In a follow-up PR, we will re-enable linking here and load the pass for them.
|
// FIXME(ZuseZ4): Get the LLVM version number automatically instead of hardcoding it.
|
||||||
if builder.config.llvm_enzyme {
|
if builder.config.llvm_enzyme {
|
||||||
cargo.rustflag("-l").rustflag("Enzyme-19");
|
cargo.rustflag("-l").rustflag("Enzyme-19");
|
||||||
}
|
}
|
||||||
|
@ -1234,6 +1234,9 @@ fn rustc_llvm_env(builder: &Builder<'_>, cargo: &mut Cargo, target: TargetSelect
|
||||||
if builder.is_rust_llvm(target) {
|
if builder.is_rust_llvm(target) {
|
||||||
cargo.env("LLVM_RUSTLLVM", "1");
|
cargo.env("LLVM_RUSTLLVM", "1");
|
||||||
}
|
}
|
||||||
|
if builder.config.llvm_enzyme {
|
||||||
|
cargo.env("LLVM_ENZYME", "1");
|
||||||
|
}
|
||||||
let llvm::LlvmResult { llvm_config, .. } = builder.ensure(llvm::Llvm { target });
|
let llvm::LlvmResult { llvm_config, .. } = builder.ensure(llvm::Llvm { target });
|
||||||
cargo.env("LLVM_CONFIG", &llvm_config);
|
cargo.env("LLVM_CONFIG", &llvm_config);
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,9 @@ fn square(x: &f64) -> f64 {
|
||||||
// CHECK-NEXT:invertstart:
|
// CHECK-NEXT:invertstart:
|
||||||
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
|
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
|
||||||
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
|
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
|
||||||
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
|
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
|
||||||
// CHECK-NEXT: %2 = fadd fast double %1, %0
|
// CHECK-NEXT: %2 = fadd fast double %1, %0
|
||||||
// CHECK-NEXT: store double %2, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
|
// CHECK-NEXT: store double %2, ptr %"x'", align 8
|
||||||
// CHECK-NEXT: ret double %_0
|
// CHECK-NEXT: ret double %_0
|
||||||
// CHECK-NEXT:}
|
// CHECK-NEXT:}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue