1
Fork 0

Auto merge of #130060 - EnzymeAD:enzyme-cg-llvm, r=oli-obk

Autodiff Upstreaming - rustc_codegen_llvm changes

Now that the autodiff/Enzyme backend is merged, this is an upstream PR for the `rustc_codegen_llvm` changes.
It also includes small changes to three files under `compiler/rustc_ast`, which overlap with my frontend PR (https://github.com/rust-lang/rust/pull/129458).
Here I only include minimal definitions of structs and enums to be able to build this backend code.
The same goes for minimal changes to `compiler/rustc_codegen_ssa`, the majority of changes there will be in another PR, once either this or the frontend gets merged.

We currently have 68 files left to merge, 19 in the frontend PR, 21 (+3 from the frontend) in this PR, and then ~30 in the middle-end.

This PR is large because it includes two of my three large files (~800 loc each). I could also first only upstream enzyme_ffi.rs, but I think people might want to see some use of these bindings in the same PR?

To already highlight the things which reviewers might want to discuss:

1) `enzyme_ffi.rs`: I do have a fallback module to make sure that we don't link rustc against Enzyme when we build rustc without autodiff support.

2) `add_panic_msg_to_global` was a pain to write and I currently can't even use it. Enzyme writes gradients into shadow memory. Pass in one float scalar? We'll allocate and return an extra float telling you how this float affected the output. Pass in a slice of floats? We'll let you allocate the vector and pass in a mutable reference to a float slice, we'll then write the gradient into that slice. It should be at least as large as your original slice, so we check that and panic if not. Currently we panic silently, but I already generate a nicer panic message with this function. I just don't know how to print it to the user. yet. I discussed this with a few rustc devs and the best we could come up with (for now), was to look for mangled panic calls in the IR and pick one, which works surprisingly reliably. If someone knows a good way to clean this up and print the panic message I'm all in, otherwise I can remove the code that writes the nicer panic message and keep the silent panic, since it's enough for soundness. Especially since this PR is already a bit larger.

3) `SanitizeHWAddress`: When differentiating C++, Enzyme can use TBAA to "understand" enums/unions, but for Rust we don't have this information. LLVM might to speculative loads which (without TBAA) confuse Enzyme, so we disable those with this attribute. This attribute is only set during the first opt run before Enzyme differentiates code. We then remove it again once we are done with autodiff and run the opt pipeline a second time. Since enums are everywhere in Rust, support for them is crucial, but if this looks too cursed I can remove these ~100 lines and keep them in my fork for now, we can then discuss them separately to make this PR simpler?

4) Duplicated llvm-opt runs: Differentiating already optimized code (and being able to do additional optimizations on the fly, e.g. for GPU code) is _the_ reason why Enzyme is so fast, so the compile time is acceptable for autodiff users:  https://enzyme.mit.edu/talks/Publications/ (There are also algorithmic issues in Enzyme core which are more serious than running opt twice).

5) I assume that if we merge these minimal cg_ssa changes here already, I also need to fix the other backends (GCC and cliff) to have dummy implementations, correct?

6) *I'm happy to split this PR up further if reviewers have recommendations on how to.*

For the full implementation, see: https://github.com/rust-lang/rust/pull/129175

Tracking:

- https://github.com/rust-lang/rust/issues/124509
This commit is contained in:
bors 2025-01-02 00:20:57 +00:00
commit 504f4f5275
17 changed files with 610 additions and 28 deletions

View file

@ -6,7 +6,6 @@
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use std::str::FromStr; use std::str::FromStr;
use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P; use crate::ptr::P;
use crate::{Ty, TyKind}; use crate::{Ty, TyKind};
@ -79,10 +78,6 @@ pub struct AutoDiffItem {
/// The name of the function being generated /// The name of the function being generated
pub target: String, pub target: String,
pub attrs: AutoDiffAttrs, pub attrs: AutoDiffAttrs,
/// Describe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Describe the memory layout of the output type
pub output: TypeTree,
} }
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs { pub struct AutoDiffAttrs {
@ -262,22 +257,14 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source) !matches!(self.mode, DiffMode::Error | DiffMode::Source)
} }
pub fn into_item( pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
self, AutoDiffItem { source, target, attrs: self }
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
} }
} }
impl fmt::Display for AutoDiffItem { impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?; write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?; write!(f, " with attributes: {:?}", self.attrs)
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
} }
} }

View file

@ -93,6 +93,7 @@ use gccjit::{CType, Context, OptimizationLevel};
#[cfg(feature = "master")] #[cfg(feature = "master")]
use gccjit::{TargetInfo, Version}; use gccjit::{TargetInfo, Version};
use rustc_ast::expand::allocator::AllocatorKind; use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn, CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@ -439,6 +440,15 @@ impl WriteBackendMethods for GccCodegenBackend {
) -> Result<ModuleCodegen<Self::Module>, FatalError> { ) -> Result<ModuleCodegen<Self::Module>, FatalError> {
back::write::link(cgcx, dcx, modules) back::write::link(cgcx, dcx, modules)
} }
fn autodiff(
_cgcx: &CodegenContext<Self>,
_tcx: TyCtxt<'_>,
_module: &ModuleCodegen<Self::Module>,
_diff_fncs: Vec<AutoDiffItem>,
_config: &ModuleConfig,
) -> Result<(), FatalError> {
unimplemented!()
}
} }
/// This is the entrypoint for a hot plugged rustc_codegen_gccjit /// This is the entrypoint for a hot plugged rustc_codegen_gccjit

View file

@ -1,3 +1,5 @@
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err} codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
codegen_llvm_dynamic_linking_with_lto = codegen_llvm_dynamic_linking_with_lto =
@ -47,6 +49,8 @@ codegen_llvm_parse_bitcode_with_llvm_err = failed to parse bitcode for LTO modul
codegen_llvm_parse_target_machine_config = codegen_llvm_parse_target_machine_config =
failed to parse target machine config to target machine: {$error} failed to parse target machine config to target machine: {$error}
codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context
codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err} codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err}

View file

@ -604,7 +604,14 @@ pub(crate) fn run_pass_manager(
debug!("running the pass manager"); debug!("running the pass manager");
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
let first_run = true;
debug!("running llvm pm opt pipeline");
unsafe {
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
}
debug!("lto done"); debug!("lto done");
Ok(()) Ok(())
} }

View file

@ -27,7 +27,7 @@ use rustc_session::config::{
}; };
use rustc_span::{BytePos, InnerSpan, Pos, SpanData, SyntaxContext, sym}; use rustc_span::{BytePos, InnerSpan, Pos, SpanData, SyntaxContext, sym};
use rustc_target::spec::{CodeModel, FloatAbi, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel}; use rustc_target::spec::{CodeModel, FloatAbi, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
use tracing::debug; use tracing::{debug, trace};
use crate::back::lto::ThinBuffer; use crate::back::lto::ThinBuffer;
use crate::back::owned_target_machine::OwnedTargetMachine; use crate::back::owned_target_machine::OwnedTargetMachine;
@ -537,9 +537,35 @@ pub(crate) unsafe fn llvm_optimize(
config: &ModuleConfig, config: &ModuleConfig,
opt_level: config::OptLevel, opt_level: config::OptLevel,
opt_stage: llvm::OptStage, opt_stage: llvm::OptStage,
skip_size_increasing_opts: bool,
) -> Result<(), FatalError> { ) -> Result<(), FatalError> {
let unroll_loops = // Enzyme:
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
// source code. However, benchmarks show that optimizations increasing the code size
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
// and finally re-optimize the module, now with all optimizations available.
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
// differentiated.
let unroll_loops;
let vectorize_slp;
let vectorize_loop;
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
// optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
// we should make this more granular, or at least check that the user has at least one autodiff
// call in their code, to justify altering the compilation pipeline.
if skip_size_increasing_opts && cfg!(llvm_enzyme) {
unroll_loops = false;
vectorize_slp = false;
vectorize_loop = false;
} else {
unroll_loops =
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
vectorize_slp = config.vectorize_slp;
vectorize_loop = config.vectorize_loop;
}
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop);
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
let pgo_gen_path = get_pgo_gen_path(config); let pgo_gen_path = get_pgo_gen_path(config);
let pgo_use_path = get_pgo_use_path(config); let pgo_use_path = get_pgo_use_path(config);
@ -603,8 +629,8 @@ pub(crate) unsafe fn llvm_optimize(
using_thin_buffers, using_thin_buffers,
config.merge_functions, config.merge_functions,
unroll_loops, unroll_loops,
config.vectorize_slp, vectorize_slp,
config.vectorize_loop, vectorize_loop,
config.no_builtins, config.no_builtins,
config.emit_lifetime_markers, config.emit_lifetime_markers,
sanitizer_options.as_ref(), sanitizer_options.as_ref(),
@ -648,6 +674,8 @@ pub(crate) unsafe fn optimize(
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) }; unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
} }
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
if let Some(opt_level) = config.opt_level { if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto { let opt_stage = match cgcx.lto {
Lto::Fat => llvm::OptStage::PreLinkFatLTO, Lto::Fat => llvm::OptStage::PreLinkFatLTO,
@ -655,7 +683,20 @@ pub(crate) unsafe fn optimize(
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO, _ => llvm::OptStage::PreLinkNoLTO,
}; };
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
// If we know that we will later run AD, then we disable vectorization and loop unrolling
let skip_size_increasing_opts = cfg!(llvm_enzyme);
return unsafe {
llvm_optimize(
cgcx,
dcx,
module,
config,
opt_level,
opt_stage,
skip_size_increasing_opts,
)
};
} }
Ok(()) Ok(())
} }

View file

@ -2,6 +2,8 @@ use std::borrow::Cow;
use std::ops::Deref; use std::ops::Deref;
use std::{iter, ptr}; use std::{iter, ptr};
pub(crate) mod autodiff;
use libc::{c_char, c_uint}; use libc::{c_char, c_uint};
use rustc_abi as abi; use rustc_abi as abi;
use rustc_abi::{Align, Size, WrappingRange}; use rustc_abi::{Align, Size, WrappingRange};

View file

@ -0,0 +1,344 @@
use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_errors::FatalError;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::Lto;
use tracing::{debug, trace};
use crate::back::write::{llvm_err, llvm_optimize};
use crate::builder::Builder;
use crate::declare::declare_raw_fn;
use crate::errors::LlvmError;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
fn get_params(fnc: &Value) -> Vec<&Value> {
unsafe {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
fnc_args.set_len(param_num);
fnc_args
}
}
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
/// function and handle the differences between the Rust calling convention and
/// Enzyme.
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
fn generate_enzyme_call<'ll, 'tcx>(
cx: &context::CodegenCx<'ll, 'tcx>,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
attrs: AutoDiffAttrs,
) {
let inputs = attrs.input_activity;
let output = attrs.ret_activity;
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
// it will handle higher-order derivatives correctly automatically (in theory). Currently
// higher-order derivatives fail, so we should debug that before adjusting this code.
let mut ad_name: String = match attrs.mode {
DiffMode::Forward => "__enzyme_fwddiff",
DiffMode::Reverse => "__enzyme_autodiff",
DiffMode::ForwardFirst => "__enzyme_fwddiff",
DiffMode::ReverseFirst => "__enzyme_autodiff",
_ => panic!("logic bug in autodiff, unrecognized mode"),
}
.to_string();
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
let name = llvm::get_value_name(outer_fn);
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
ad_name.push_str(outer_fn_name.to_string().as_str());
// Let us assume the user wrote the following function square:
//
// ```llvm
// define double @square(double %x) {
// entry:
// %0 = fmul double %x, %x
// ret double %0
// }
// ```
//
// The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
// Our macro generates the following placeholder code (slightly simplified):
//
// ```llvm
// define double @dsquare(double %x) {
// ; placeholder code
// return 0.0;
// }
// ```
//
// so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
// code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
// Again, the arguments to all functions are slightly simplified.
// ```llvm
// declare double @__enzyme_autodiff_square(...)
//
// define double @dsquare(double %x) {
// entry:
// %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
// ret double %0
// }
// ```
unsafe {
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
// LLVM can figure out the input types on it's own, so we take a shortcut here.
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_raw_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
llvm::UnnamedAddr::No,
llvm::Visibility::Default,
enzyme_ty,
);
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
// do it's work.
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
// first, remove all calls from fnc
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
let br = llvm::LLVMRustGetTerminator(entry);
llvm::LLVMRustEraseInstFromParent(br);
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
let mut builder = Builder::build(cx, entry);
let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap();
match output {
DiffActivity::Dual => {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
DiffActivity::Active => {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
_ => {}
}
trace!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
// need to match those.
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
let mut outer_pos: usize = 0;
let mut activity_pos = 0;
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
while activity_pos < inputs.len() {
let activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the
// gradient.
let (activity, duplicated): (&Metadata, bool) = match activity {
DiffActivity::None => panic!("not a valid input activity"),
DiffActivity::Const => (enzyme_const, false),
DiffActivity::Active => (enzyme_out, false),
DiffActivity::ActiveOnly => (enzyme_out, false),
DiffActivity::Dual => (enzyme_dup, true),
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
DiffActivity::Duplicated => (enzyme_dup, true),
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
DiffActivity::FakeActivitySize => (enzyme_const, false),
};
let outer_arg = outer_args[outer_pos];
args.push(cx.get_metadata_value(activity));
args.push(outer_arg);
if duplicated {
// We know that duplicated args by construction have a following argument,
// so this can not be out of bounds.
let next_outer_arg = outer_args[outer_pos + 1];
let next_outer_ty = cx.val_ty(next_outer_arg);
// FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
// vectors behind references (&Vec<T>) are already supported. Users can not pass a
// Vec by value for reverse mode, so this would only help forward mode autodiff.
let slice = {
if activity_pos + 1 >= inputs.len() {
// If there is no arg following our ptr, it also can't be a slice,
// since that would lead to a ptr, int pair.
false
} else {
let next_activity = inputs[activity_pos + 1];
// We analyze the MIR types and add this dummy activity if we visit a slice.
next_activity == DiffActivity::FakeActivitySize
}
};
if slice {
// A duplicated slice will have the following two outer_fn arguments:
// (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer);
let next_outer_arg2 = outer_args[outer_pos + 2];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer);
let next_outer_arg3 = outer_args[outer_pos + 3];
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer);
args.push(next_outer_arg2);
args.push(cx.get_metadata_value(enzyme_const));
args.push(next_outer_arg);
outer_pos += 4;
activity_pos += 2;
} else {
// A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...).
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
args.push(next_outer_arg);
outer_pos += 2;
activity_pos += 1;
}
} else {
// We do not differentiate with resprect to this argument.
// We already added the metadata and argument above, so just increase the counters.
outer_pos += 1;
activity_pos += 1;
}
}
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attachted to it, but we just created this code oota. Given that the
// differentiated function already has partly confusing metadata, and given that this
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
// dummy code which we inserted at a higher level.
// FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
// and how to best improve it for enzyme core and rust-enzyme.
let md_ty = cx.get_md_kind_id("dbg");
if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
.expect("failed to get instruction metadata");
let md_todiff = cx.get_metadata_value(md);
llvm::LLVMSetMetadata(call, md_ty, md_todiff);
} else {
// We don't panic, since depending on whether we are in debug or release mode, we might
// have no debug info to copy, which would then be ok.
trace!("no dbg info");
}
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstBefore(entry, last_inst);
llvm::LLVMRustEraseInstFromParent(last_inst);
if cx.val_ty(outer_fn) != cx.type_void() {
builder.ret(call);
} else {
builder.ret_void();
}
// Let's crash in case that we messed something up above and generated invalid IR.
llvm::LLVMRustVerifyFunction(
outer_fn,
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
);
}
}
pub(crate) fn differentiate<'ll, 'tcx>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
tcx: TyCtxt<'tcx>,
diff_items: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
for item in &diff_items {
trace!("{}", item);
}
let diag_handler = cgcx.create_dcx();
let (_, cgus) = tcx.collect_and_partition_mono_items(());
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
// Before dumping the module, we want all the TypeTrees to become part of the module.
for item in diff_items.iter() {
let name = item.source.clone();
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
let Some(fn_def) = fn_def else {
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find source function".to_owned(),
}));
};
debug!(?item.target);
let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
let Some(fn_target) = fn_target else {
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find target function".to_owned(),
}));
};
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
}
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
if let Some(opt_level) = config.opt_level {
let opt_stage = match cgcx.lto {
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
_ => llvm::OptStage::PreLinkNoLTO,
};
// This is our second opt call, so now we run all opts,
// to make sure we get the best performance.
let skip_size_increasing_opts = false;
trace!("running Module Optimization after differentiation");
unsafe {
llvm_optimize(
cgcx,
diag_handler.handle(),
module,
config,
opt_level,
opt_stage,
skip_size_increasing_opts,
)?
};
}
trace!("done with differentiate()");
Ok(())
}

View file

@ -1,6 +1,6 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::ffi::{CStr, c_uint}; use std::ffi::{CStr, c_char, c_uint};
use std::str; use std::str;
use rustc_abi::{HasDataLayout, TargetDataLayout, VariantIdx}; use rustc_abi::{HasDataLayout, TargetDataLayout, VariantIdx};
@ -600,6 +600,31 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
llvm::set_section(g, c"llvm.metadata"); llvm::set_section(g, c"llvm.metadata");
} }
} }
pub(crate) fn get_metadata_value(&self, metadata: &'ll Metadata) -> &'ll Value {
unsafe { llvm::LLVMMetadataAsValue(self.llcx, metadata) }
}
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
let name = SmallCStr::new(name);
unsafe { llvm::LLVMGetNamedFunction(self.llmod, name.as_ptr()) }
}
pub(crate) fn get_md_kind_id(&self, name: &str) -> u32 {
unsafe {
llvm::LLVMGetMDKindIDInContext(
self.llcx,
name.as_ptr() as *const c_char,
name.len() as c_uint,
)
}
}
pub(crate) fn create_metadata(&self, name: String) -> Option<&'ll Metadata> {
Some(unsafe {
llvm::LLVMMDStringInContext2(self.llcx, name.as_ptr() as *const c_char, name.len())
})
}
} }
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> { impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

View file

@ -32,7 +32,7 @@ use crate::{attributes, llvm};
/// ///
/// If theres a value with the same name already declared, the function will /// If theres a value with the same name already declared, the function will
/// update the declaration and return existing Value instead. /// update the declaration and return existing Value instead.
fn declare_raw_fn<'ll>( pub(crate) fn declare_raw_fn<'ll>(
cx: &CodegenCx<'ll, '_>, cx: &CodegenCx<'ll, '_>,
name: &str, name: &str,
callconv: llvm::CallConv, callconv: llvm::CallConv,

View file

@ -89,6 +89,11 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
} }
} }
#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_lto)]
#[note]
pub(crate) struct AutoDiffWithoutLTO;
#[derive(Diagnostic)] #[derive(Diagnostic)]
#[diag(codegen_llvm_lto_disallowed)] #[diag(codegen_llvm_lto_disallowed)]
pub(crate) struct LtoDisallowed; pub(crate) struct LtoDisallowed;
@ -131,6 +136,8 @@ pub enum LlvmError<'a> {
PrepareThinLtoModule, PrepareThinLtoModule,
#[diag(codegen_llvm_parse_bitcode)] #[diag(codegen_llvm_parse_bitcode)]
ParseBitcode, ParseBitcode,
#[diag(codegen_llvm_prepare_autodiff)]
PrepareAutoDiff { src: String, target: String, error: String },
} }
pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String);
@ -152,6 +159,7 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for WithLlvmError<'_> {
} }
PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err,
ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err,
PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err,
}; };
self.0 self.0
.into_diag(dcx, level) .into_diag(dcx, level)

View file

@ -28,9 +28,10 @@ use std::mem::ManuallyDrop;
use back::owned_target_machine::OwnedTargetMachine; use back::owned_target_machine::OwnedTargetMachine;
use back::write::{create_informational_target_machine, create_target_machine}; use back::write::{create_informational_target_machine, create_target_machine};
use errors::ParseTargetMachineConfig; use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig};
pub use llvm_util::target_features_cfg; pub use llvm_util::target_features_cfg;
use rustc_ast::expand::allocator::AllocatorKind; use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn,
@ -44,7 +45,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
use rustc_middle::ty::TyCtxt; use rustc_middle::ty::TyCtxt;
use rustc_middle::util::Providers; use rustc_middle::util::Providers;
use rustc_session::Session; use rustc_session::Session;
use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest};
use rustc_span::Symbol; use rustc_span::Symbol;
mod back { mod back {
@ -233,6 +234,20 @@ impl WriteBackendMethods for LlvmCodegenBackend {
fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer) { fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer) {
(module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod()))
} }
/// Generate autodiff rules
fn autodiff(
cgcx: &CodegenContext<Self>,
tcx: TyCtxt<'_>,
module: &ModuleCodegen<Self::Module>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
if cgcx.lto != Lto::Fat {
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO));
}
builder::autodiff::differentiate(module, cgcx, tcx, diff_fncs, config)
}
} }
unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis

View file

@ -0,0 +1,29 @@
#![allow(non_camel_case_types)]
use libc::{c_char, c_uint};
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
use crate::llvm::Bool;
extern "C" {
// Enzyme
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
pub fn LLVMRustEraseInstFromParent(V: &Value);
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
pub fn LLVMGetReturnType(T: &Type) -> &Type;
pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value);
pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
}
#[repr(C)]
#[derive(Copy, Clone, PartialEq)]
pub enum LLVMRustVerifierFailureAction {
LLVMAbortProcessAction = 0,
LLVMPrintMessageAction = 1,
LLVMReturnStatusAction = 2,
}

View file

@ -99,7 +99,7 @@ pub enum ModuleFlagMergeBehavior {
/// LLVM CallingConv::ID. Should we wrap this? /// LLVM CallingConv::ID. Should we wrap this?
/// ///
/// See <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h> /// See <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
#[derive(Copy, Clone, PartialEq, Debug)] #[derive(Copy, Clone, PartialEq, Debug, TryFromU32)]
#[repr(C)] #[repr(C)]
pub enum CallConv { pub enum CallConv {
CCallConv = 0, CCallConv = 0,

View file

@ -22,8 +22,11 @@ use crate::common::AsCCharPtr;
pub mod archive_ro; pub mod archive_ro;
pub mod diagnostic; pub mod diagnostic;
pub mod enzyme_ffi;
mod ffi; mod ffi;
pub use self::enzyme_ffi::*;
impl LLVMRustResult { impl LLVMRustResult {
pub fn into_result(self) -> Result<(), ()> { pub fn into_result(self) -> Result<(), ()> {
match self { match self {

View file

@ -1,11 +1,14 @@
use std::ffi::CString; use std::ffi::CString;
use std::sync::Arc; use std::sync::Arc;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::memmap::Mmap; use rustc_data_structures::memmap::Mmap;
use rustc_errors::FatalError; use rustc_errors::FatalError;
use rustc_middle::ty::TyCtxt;
use super::write::CodegenContext; use super::write::CodegenContext;
use crate::ModuleCodegen; use crate::ModuleCodegen;
use crate::back::write::ModuleConfig;
use crate::traits::*; use crate::traits::*;
pub struct ThinModule<B: WriteBackendMethods> { pub struct ThinModule<B: WriteBackendMethods> {
@ -81,6 +84,24 @@ impl<B: WriteBackendMethods> LtoModuleCodegen<B> {
LtoModuleCodegen::Thin(ref m) => m.cost(), LtoModuleCodegen::Thin(ref m) => m.cost(),
} }
} }
/// Run autodiff on Fat LTO module
pub unsafe fn autodiff(
self,
cgcx: &CodegenContext<B>,
tcx: TyCtxt<'_>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<LtoModuleCodegen<B>, FatalError> {
match &self {
LtoModuleCodegen::Fat(module) => {
B::autodiff(cgcx, tcx, &module, diff_fncs, config)?;
}
_ => panic!("autodiff called with non-fat LTO module"),
}
Ok(self)
}
} }
pub enum SerializedModule<M: ModuleBufferMethods> { pub enum SerializedModule<M: ModuleBufferMethods> {

View file

@ -1,5 +1,7 @@
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_errors::{DiagCtxtHandle, FatalError};
use rustc_middle::dep_graph::WorkProduct; use rustc_middle::dep_graph::WorkProduct;
use rustc_middle::ty::TyCtxt;
use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig};
@ -61,6 +63,13 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
want_summary: bool, want_summary: bool,
) -> (String, Self::ThinBuffer); ) -> (String, Self::ThinBuffer);
fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer); fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer);
fn autodiff(
cgcx: &CodegenContext<Self>,
tcx: TyCtxt<'_>,
module: &ModuleCodegen<Self::Module>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError>;
} }
pub trait ThinBufferMethods: Send + Sync { pub trait ThinBufferMethods: Send + Sync {

View file

@ -1,5 +1,6 @@
#include "LLVMWrapper.h" #include "LLVMWrapper.h"
#include "llvm-c/Analysis.h"
#include "llvm-c/Core.h" #include "llvm-c/Core.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@ -165,6 +166,30 @@ extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen))); return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));
} }
enum class LLVMRustVerifierFailureAction {
AbortProcessAction = 0,
PrintMessageAction = 1,
ReturnStatusAction = 2,
};
static LLVMVerifierFailureAction
fromRust(LLVMRustVerifierFailureAction Action) {
switch (Action) {
case LLVMRustVerifierFailureAction::AbortProcessAction:
return LLVMAbortProcessAction;
case LLVMRustVerifierFailureAction::PrintMessageAction:
return LLVMPrintMessageAction;
case LLVMRustVerifierFailureAction::ReturnStatusAction:
return LLVMReturnStatusAction;
}
report_fatal_error("Invalid LLVMVerifierFailureAction value!");
}
extern "C" LLVMBool
LLVMRustVerifyFunction(LLVMValueRef Fn, LLVMRustVerifierFailureAction Action) {
return LLVMVerifyFunction(Fn, fromRust(Action));
}
enum class LLVMRustTailCallKind { enum class LLVMRustTailCallKind {
None, None,
Tail, Tail,
@ -388,6 +413,17 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr,
AddAttributes(Call, Index, Attrs, AttrsLen); AddAttributes(Call, Index, Attrs, AttrsLen);
} }
extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) {
Instruction *ret = unwrap(BB)->getTerminator();
return wrap(ret);
}
extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) {
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
I->eraseFromParent();
}
}
extern "C" LLVMAttributeRef extern "C" LLVMAttributeRef
LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) { LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) {
return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr)));
@ -954,6 +990,47 @@ extern "C" void LLVMRustAddModuleFlagString(
MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen))); MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen)));
} }
extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) {
auto Point = unwrap(BB)->rbegin();
if (Point != unwrap(BB)->rend())
return wrap(&*Point);
return nullptr;
}
extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
auto &BB = *unwrap(bb);
auto &Inst = *unwrap<Instruction>(I);
auto It = BB.begin();
while (&*It != &Inst)
++It;
// Make sure we found the Instruction.
assert(It != BB.end());
// We don't want to erase the instruction itself.
It--;
// Delete in rev order to ensure no dangling references.
while (It != BB.begin()) {
auto Prev = std::prev(It);
It->eraseFromParent();
It = Prev;
}
It->eraseFromParent();
}
extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) {
if (auto *I = dyn_cast<Instruction>(unwrap<Value>(inst))) {
return I->hasMetadata(kindID);
}
return false;
}
extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) {
if (auto *I = dyn_cast<Instruction>(unwrap<Value>(x))) {
auto *MD = I->getDebugLoc().getAsMDNode();
return wrap(MD);
}
return nullptr;
}
extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind, extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind,
LLVMMetadataRef MD) { LLVMMetadataRef MD) {
unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD)); unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD));