update autodiff flags
This commit is contained in:
parent
161a4bf6ff
commit
e2d250c3f6
11 changed files with 204 additions and 76 deletions
|
@ -1,3 +1,4 @@
|
|||
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
|
||||
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
|
||||
|
||||
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
|
||||
|
|
|
@ -586,6 +586,42 @@ fn thin_lto(
|
|||
}
|
||||
}
|
||||
|
||||
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
|
||||
for &val in ad {
|
||||
match val {
|
||||
config::AutoDiff::PrintModBefore => {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
config::AutoDiff::PrintPerf => {
|
||||
llvm::set_print_perf(true);
|
||||
}
|
||||
config::AutoDiff::PrintAA => {
|
||||
llvm::set_print_activity(true);
|
||||
}
|
||||
config::AutoDiff::PrintTA => {
|
||||
llvm::set_print_type(true);
|
||||
}
|
||||
config::AutoDiff::Inline => {
|
||||
llvm::set_inline(true);
|
||||
}
|
||||
config::AutoDiff::LooseTypes => {
|
||||
llvm::set_loose_types(false);
|
||||
}
|
||||
config::AutoDiff::PrintSteps => {
|
||||
llvm::set_print(true);
|
||||
}
|
||||
// We handle this below
|
||||
config::AutoDiff::PrintModAfter => {}
|
||||
// This is required and already checked
|
||||
config::AutoDiff::Enable => {}
|
||||
}
|
||||
}
|
||||
// This helps with handling enums for now.
|
||||
llvm::set_strict_aliasing(false);
|
||||
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
|
||||
llvm::set_rust_rules(true);
|
||||
}
|
||||
|
||||
pub(crate) fn run_pass_manager(
|
||||
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
||||
dcx: DiagCtxtHandle<'_>,
|
||||
|
@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
|
|||
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
|
||||
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
|
||||
|
||||
// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
|
||||
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
|
||||
debug!("running llvm pm opt pipeline");
|
||||
unsafe {
|
||||
write::llvm_optimize(
|
||||
cgcx,
|
||||
dcx,
|
||||
module,
|
||||
config,
|
||||
opt_level,
|
||||
opt_stage,
|
||||
write::AutodiffStage::DuringAD,
|
||||
)?;
|
||||
// The PostAD behavior is the same that we would have if no autodiff was used.
|
||||
// It will run the default optimization pipeline. If AD is enabled we select
|
||||
// the DuringAD stage, which will disable vectorization and loop unrolling, and
|
||||
// schedule two autodiff optimization + differentiation passes.
|
||||
// We then run the llvm_optimize function a second time, to optimize the code which we generated
|
||||
// in the enzyme differentiation pass.
|
||||
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
|
||||
let stage =
|
||||
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
|
||||
|
||||
if enable_ad {
|
||||
enable_autodiff_settings(&config.autodiff, module);
|
||||
}
|
||||
// FIXME(ZuseZ4): Make this more granular
|
||||
if cfg!(llvm_enzyme) && !thin {
|
||||
|
||||
unsafe {
|
||||
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
|
||||
}
|
||||
|
||||
if cfg!(llvm_enzyme) && enable_ad {
|
||||
let opt_stage = llvm::OptStage::FatLTO;
|
||||
let stage = write::AutodiffStage::PostAD;
|
||||
unsafe {
|
||||
write::llvm_optimize(
|
||||
cgcx,
|
||||
dcx,
|
||||
module,
|
||||
config,
|
||||
opt_level,
|
||||
llvm::OptStage::FatLTO,
|
||||
write::AutodiffStage::PostAD,
|
||||
)?;
|
||||
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
|
||||
}
|
||||
|
||||
// This is the final IR, so people should be able to inspect the optimized autodiff output.
|
||||
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
|
||||
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
|
||||
}
|
||||
}
|
||||
|
||||
debug!("lto done");
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use crate::back::write::llvm_err;
|
|||
use crate::builder::SBuilder;
|
||||
use crate::context::SimpleCx;
|
||||
use crate::declare::declare_simple_fn;
|
||||
use crate::errors::LlvmError;
|
||||
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
|
||||
use crate::llvm::AttributePlace::Function;
|
||||
use crate::llvm::{Metadata, True};
|
||||
use crate::value::Value;
|
||||
|
@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
|
|||
let output = attrs.ret_activity;
|
||||
|
||||
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
|
||||
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
|
||||
// it will handle higher-order derivatives correctly automatically (in theory). Currently
|
||||
// higher-order derivatives fail, so we should debug that before adjusting this code.
|
||||
let mut ad_name: String = match attrs.mode {
|
||||
DiffMode::Forward => "__enzyme_fwddiff",
|
||||
DiffMode::Reverse => "__enzyme_autodiff",
|
||||
|
@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
|
|||
let diag_handler = cgcx.create_dcx();
|
||||
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
|
||||
|
||||
// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
|
||||
if !diff_items.is_empty()
|
||||
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
|
||||
{
|
||||
let dcx = cgcx.create_dcx();
|
||||
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
|
||||
}
|
||||
|
||||
// Before dumping the module, we want all the TypeTrees to become part of the module.
|
||||
for item in diff_items.iter() {
|
||||
let name = item.source.clone();
|
||||
|
|
|
@ -92,9 +92,12 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
|
|||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(codegen_llvm_autodiff_without_lto)]
|
||||
#[note]
|
||||
pub(crate) struct AutoDiffWithoutLTO;
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(codegen_llvm_autodiff_without_enable)]
|
||||
pub(crate) struct AutoDiffWithoutEnable;
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(codegen_llvm_lto_disallowed)]
|
||||
pub(crate) struct LtoDisallowed;
|
||||
|
|
|
@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction {
|
|||
LLVMPrintMessageAction = 1,
|
||||
LLVMReturnStatusAction = 2,
|
||||
}
|
||||
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub use self::Enzyme_AD::*;
|
||||
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub mod Enzyme_AD {
|
||||
use libc::c_void;
|
||||
extern "C" {
|
||||
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
|
||||
}
|
||||
extern "C" {
|
||||
static mut EnzymePrintPerf: c_void;
|
||||
static mut EnzymePrintActivity: c_void;
|
||||
static mut EnzymePrintType: c_void;
|
||||
static mut EnzymePrint: c_void;
|
||||
static mut EnzymeStrictAliasing: c_void;
|
||||
static mut looseTypeAnalysis: c_void;
|
||||
static mut EnzymeInline: c_void;
|
||||
static mut RustTypeRules: c_void;
|
||||
}
|
||||
pub fn set_print_perf(print: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_print_activity(print: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_print_type(print: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_print(print: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_strict_aliasing(strict: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_loose_types(loose: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_inline(val: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
|
||||
}
|
||||
}
|
||||
pub fn set_rust_rules(val: bool) {
|
||||
unsafe {
|
||||
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
pub use self::Fallback_AD::*;
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
pub mod Fallback_AD {
|
||||
#![allow(unused_variables)]
|
||||
|
||||
pub fn set_inline(val: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_print_perf(print: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_print_activity(print: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_print_type(print: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_print(print: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_strict_aliasing(strict: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_loose_types(loose: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
pub fn set_rust_rules(val: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue