Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk

Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle

This PR should not be merged until the rustc_codegen_llvm part is merged.
I will also alter it a little based on what get's shaved off from the cg_llvm PR,
and address some of the feedback I received in the other PR (including cleanups).

I am putting it already up to
1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and
2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost.

Re 1: My test require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either that work pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thoughts and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please?
For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook

Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: https://github.com/EnzymeAD/rust/issues/173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just get's dropped, but I don't know WHY? Any help would be super appreciated, as rustc queries look a bit voodoo to me.

Tracking:

- https://github.com/rust-lang/rust/issues/124509

r? `@jieyouxu`
This commit is contained in:
Jacob Pratt 2025-01-31 00:26:30 -05:00 committed by GitHub
commit c19c4b91f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 482 additions and 38 deletions

View file

@ -4234,6 +4234,7 @@ name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_abi",
"rustc_ast",
"rustc_attr_parsing",
"rustc_data_structures",
"rustc_errors",
@ -4243,6 +4244,7 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",

View file

@ -79,6 +79,7 @@ pub struct AutoDiffItem {
pub target: String,
pub attrs: AutoDiffAttrs,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
@ -231,7 +232,7 @@ impl AutoDiffAttrs {
self.ret_activity == DiffActivity::ActiveOnly
}
pub fn error() -> Self {
pub const fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
ret_activity: DiffActivity::None,

View file

@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>(
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
let name = llvm::get_value_name(outer_fn);
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
ad_name.push_str(outer_fn_name.to_string().as_str());
let outer_fn_name = std::str::from_utf8(name).unwrap();
ad_name.push_str(outer_fn_name);
// Let us assume the user wrote the following function square:
//
@ -255,14 +255,14 @@ fn generate_enzyme_call<'ll>(
// have no debug info to copy, which would then be ok.
trace!("no dbg info");
}
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstBefore(entry, last_inst);
llvm::LLVMRustEraseInstFromParent(last_inst);
if cx.val_ty(outer_fn) != cx.type_void() {
builder.ret(call);
} else {
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
if cx.val_ty(call) == cx.type_void() {
builder.ret_void();
} else {
builder.ret(call);
}
// Let's crash in case that we messed something up above and generated invalid IR.

View file

@ -298,7 +298,7 @@ struct UsageSets<'tcx> {
/// Prepare sets of definitions that are relevant to deciding whether something
/// is an "unused function" for coverage purposes.
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
let MonoItemPartitions { all_mono_items, codegen_units } =
let MonoItemPartitions { all_mono_items, codegen_units, .. } =
tcx.collect_and_partition_mono_items(());
// Obtain a MIR body for each function participating in codegen, via an

View file

@ -7,11 +7,13 @@ use crate::llvm::Bool;
extern "C" {
// Enzyme
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
pub fn LLVMRustEraseInstFromParent(V: &Value);
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
pub fn LLVMDumpModule(M: &Module);
pub fn LLVMDumpValue(V: &Value);
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;

View file

@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro
codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty
codegen_ssa_cgu_not_recorded =

View file

@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
use std::{fs, io, mem, str, thread};
use rustc_ast::attr;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
use rustc_data_structures::jobserver::{self, Acquired};
use rustc_data_structures::memmap::Mmap;
@ -40,7 +41,7 @@ use tracing::debug;
use super::link::{self, ensure_removed};
use super::lto::{self, SerializedModule};
use super::symbol_export::symbol_name_for_instance_in_crate;
use crate::errors::ErrorCreatingRemarkDir;
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
use crate::traits::*;
use crate::{
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@ -118,6 +119,7 @@ pub struct ModuleConfig {
pub merge_functions: bool,
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
}
impl ModuleConfig {
@ -266,6 +268,7 @@ impl ModuleConfig {
emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
}
}
@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
fn generate_lto_work<B: ExtraBackendMethods>(
cgcx: &CodegenContext<B>,
autodiff: Vec<AutoDiffItem>,
needs_fat_lto: Vec<FatLtoInput<B>>,
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
if !needs_fat_lto.is_empty() {
assert!(needs_thin_lto.is_empty());
let module =
let mut module =
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
if cgcx.lto == Lto::Fat {
let config = cgcx.config(ModuleKind::Regular);
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
}
// We are adding a single work item, so the cost doesn't matter.
vec![(WorkItem::LTO(module), 0)]
} else {
if !autodiff.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
assert!(needs_fat_lto.is_empty());
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
.unwrap_or_else(|e| e.raise());
@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
/// Sent from a backend worker thread.
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
AddAutoDiffItems(Vec<AutoDiffItem>),
/// The frontend has finished generating something (backend IR or a
/// post-LTO artifact) for a codegen unit, and it should be passed to the
/// backend. Sent from the main thread.
@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
// This is where we collect codegen units that have gone all the way
// through codegen and LLVM.
let mut autodiff_items = Vec::new();
let mut compiled_modules = vec![];
let mut compiled_allocator_module = None;
let mut needs_link = Vec::new();
@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
let needs_thin_lto = mem::take(&mut needs_thin_lto);
let import_only_modules = mem::take(&mut lto_import_only_modules);
for (work, cost) in
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
{
for (work, cost) in generate_lto_work(
&cgcx,
autodiff_items.clone(),
needs_fat_lto,
needs_thin_lto,
import_only_modules,
) {
let insertion_index = work_items
.binary_search_by_key(&cost, |&(_, cost)| cost)
.unwrap_or_else(|e| e);
@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
main_thread_state = MainThreadState::Idle;
}
Message::AddAutoDiffItems(mut items) => {
autodiff_items.append(&mut items);
}
Message::CodegenComplete => {
if codegen_state != Aborted {
codegen_state = Completed;
@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
}
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
}
pub(crate) fn check_for_errors(&self, sess: &Session) {
self.shared_emitter_main.check(sess, false);
}

View file

@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
use rustc_middle::middle::exported_symbols::SymbolExportKind;
use rustc_middle::middle::{exported_symbols, lang_items};
use rustc_middle::mir::BinOp;
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
use rustc_middle::query::Providers;
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
@ -624,7 +624,9 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
// Run the monomorphization collector and partition the collected items into
// codegen units.
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
tcx.collect_and_partition_mono_items(());
let autodiff_fncs = autodiff_items.to_vec();
// Force all codegen_unit queries so they are already either red or green
// when compile_codegen_unit accesses them. We are not able to re-execute
@ -695,6 +697,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
);
}
if !autodiff_fncs.is_empty() {
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
}
// For better throughput during parallel processing by LLVM, we used to sort
// CGUs largest to smallest. This would lead to better thread utilization
// by, for example, preventing a large CGU from being processed last and

View file

@ -1,5 +1,10 @@
use std::str::FromStr;
use rustc_ast::attr::list_contains_name;
use rustc_ast::{MetaItemInner, attr};
use rustc_ast::expand::autodiff_attrs::{
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
};
use rustc_ast::{MetaItem, MetaItemInner, attr};
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::codes::*;
@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
};
use rustc_middle::mir::mono::Linkage;
use rustc_middle::query::Providers;
use rustc_middle::span_bug;
use rustc_middle::ty::{self as ty, TyCtxt};
use rustc_session::parse::feature_err;
use rustc_session::{Session, lint};
@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
}
// If our rustc version supports autodiff/enzyme, then we call our handler
// to check for any `#[rustc_autodiff(...)]` attributes.
if cfg!(llvm_enzyme) {
let ad = autodiff_attrs(tcx, did.into());
codegen_fn_attrs.autodiff_item = ad;
}
// When `no_builtins` is applied at the crate level, we should add the
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
}
}
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
let attrs =
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
// check for exactly one autodiff attribute on placeholder functions.
// There should only be one, since we generate a new placeholder per ad macro.
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
// looks strange e.g. under cargo-expand.
let attr = match &attrs[..] {
[] => return None,
[attr] => attr,
// These two attributes are the same and unfortunately duplicated due to a previous bug.
[attr, _attr2] => attr,
_ => {
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
//branch above.
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
}
};
let list = attr.meta_item_list().unwrap_or_default();
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
if list.is_empty() {
return Some(AutoDiffAttrs::source());
}
let [mode, input_activities @ .., ret_activity] = &list[..] else {
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
};
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
p1.segments.first().unwrap().ident
} else {
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
};
// parse mode
let mode = match mode.as_str() {
"Forward" => DiffMode::Forward,
"Reverse" => DiffMode::Reverse,
"ForwardFirst" => DiffMode::ForwardFirst,
"ReverseFirst" => DiffMode::ReverseFirst,
_ => {
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
}
};
// First read the ret symbol from the attribute
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
p1.segments.first().unwrap().ident
} else {
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
};
// Then parse it into an actual DiffActivity
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
span_bug!(ret_symbol.span, "invalid return activity");
};
// Now parse all the intermediate (input) activities
let mut arg_activities: Vec<DiffActivity> = vec![];
for arg in input_activities {
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
match p2.segments.first() {
Some(x) => x.ident,
None => {
span_bug!(
arg.span(),
"rustc_autodiff attribute must contain the input activity"
);
}
}
} else {
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
};
match DiffActivity::from_str(arg_symbol.as_str()) {
Ok(arg_activity) => arg_activities.push(arg_activity),
Err(_) => {
span_bug!(arg_symbol.span, "invalid input activity");
}
}
}
for &input in &arg_activities {
if !valid_input_activity(mode, input) {
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
}
}
if !valid_ret_activity(mode, ret_activity) {
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
}
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
}
pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
}

View file

@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
pub cgu_name: &'a str,
}
#[derive(Diagnostic)]
#[diag(codegen_ssa_autodiff_without_lto)]
pub struct AutodiffWithoutLto;
#[derive(Diagnostic)]
#[diag(codegen_ssa_unknown_reuse_kind)]
pub(crate) struct UnknownReuseKind {

View file

@ -8,12 +8,12 @@ use rustc_data_structures::profiling::TimePassesFormat;
use rustc_errors::emitter::HumanReadableErrorType;
use rustc_errors::{ColorConfig, registry};
use rustc_session::config::{
BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel, CoverageOptions,
DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation, Externs,
FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage, InstrumentXRay,
LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans, NextSolverConfig,
OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet, Passes,
PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath,
AutoDiff, BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel,
CoverageOptions, DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation,
Externs, FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage,
InstrumentXRay, LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans,
NextSolverConfig, OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet,
Passes, PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath,
SymbolManglingVersion, WasiExecModel, build_configuration, build_session_options,
rustc_optgroups,
};
@ -760,6 +760,7 @@ fn test_unstable_options_tracking_hash() {
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(assume_incomplete_release, true);
tracked!(autodiff, vec![AutoDiff::Print]);
tracked!(binary_dep_depinfo, true);
tracked!(box_noalias, false);
tracked!(

View file

@ -955,7 +955,8 @@ extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) {
return nullptr;
}
extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
extern "C" void LLVMRustEraseInstUntilInclusive(LLVMBasicBlockRef bb,
LLVMValueRef I) {
auto &BB = *unwrap(bb);
auto &Inst = *unwrap<Instruction>(I);
auto It = BB.begin();
@ -963,8 +964,6 @@ extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
++It;
// Make sure we found the Instruction.
assert(It != BB.end());
// We don't want to erase the instruction itself.
It--;
// Delete in rev order to ensure no dangling references.
while (It != BB.begin()) {
auto Prev = std::prev(It);

View file

@ -32,6 +32,8 @@ middle_assert_shl_overflow =
middle_assert_shr_overflow =
attempt to shift right by `{$val}`, which would overflow
middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe
middle_bounds_check =
index out of bounds: the length is {$len} but the index is {$index}
@ -107,6 +109,8 @@ middle_type_length_limit = reached the type-length limit while instantiating `{$
middle_unknown_layout =
the type `{$ty}` has an unknown layout
middle_unsupported_union = we don't support unions yet: '{$ty_name}'
middle_values_too_big =
values of the type `{$ty}` are too big for the target architecture

View file

@ -87,6 +87,7 @@ macro_rules! arena_types {
[] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>,
[decode] attribute: rustc_hir::Attribute,
[] name_set: rustc_data_structures::unord::UnordSet<rustc_span::Symbol>,
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
[] pats: rustc_middle::ty::PatternKind<'tcx>,

View file

@ -37,6 +37,20 @@ pub struct OpaqueHiddenTypeMismatch<'tcx> {
pub sub: TypeMismatchReason,
}
#[derive(Diagnostic)]
#[diag(middle_unsupported_union)]
pub struct UnsupportedUnion {
pub ty_name: String,
}
#[derive(Diagnostic)]
#[diag(middle_autodiff_unsafe_inner_const_ref)]
pub struct AutodiffUnsafeInnerConstRef {
#[primary_span]
pub span: Span,
pub ty: String,
}
#[derive(Subdiagnostic)]
pub enum TypeMismatchReason {
#[label(middle_conflict_types)]

View file

@ -1,4 +1,5 @@
use rustc_abi::Align;
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
use rustc_span::Symbol;
@ -52,6 +53,8 @@ pub struct CodegenFnAttrs {
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
/// the function entry.
pub patchable_function_entry: Option<PatchableFunctionEntry>,
/// For the `#[autodiff]` macros.
pub autodiff_item: Option<AutoDiffAttrs>,
}
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
@ -160,6 +163,7 @@ impl CodegenFnAttrs {
instruction_set: None,
alignment: None,
patchable_function_entry: None,
autodiff_item: None,
}
}

View file

@ -1,6 +1,7 @@
use std::fmt;
use std::hash::Hash;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_attr_parsing::InlineAttr;
use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN};
use rustc_data_structures::fingerprint::Fingerprint;
@ -246,6 +247,7 @@ impl ToStableHashKey<StableHashingContext<'_>> for MonoItem<'_> {
pub struct MonoItemPartitions<'tcx> {
pub codegen_units: &'tcx [CodegenUnit<'tcx>],
pub all_mono_items: &'tcx DefIdSet,
pub autodiff_items: &'tcx [AutoDiffItem],
}
#[derive(Debug, HashStable)]

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
# tidy-alphabetical-start
rustc_abi = { path = "../rustc_abi" }
rustc_ast = { path = "../rustc_ast" }
rustc_attr_parsing = { path = "../rustc_attr_parsing" }
rustc_data_structures = { path = "../rustc_data_structures" }
rustc_errors = { path = "../rustc_errors" }
@ -15,6 +16,7 @@ rustc_macros = { path = "../rustc_macros" }
rustc_middle = { path = "../rustc_middle" }
rustc_session = { path = "../rustc_session" }
rustc_span = { path = "../rustc_span" }
rustc_symbol_mangling = { path = "../rustc_symbol_mangling" }
rustc_target = { path = "../rustc_target" }
serde = "1"
serde_json = "1"

View file

@ -257,7 +257,7 @@ struct SharedState<'tcx> {
pub(crate) struct UsageMap<'tcx> {
// Maps every mono item to the mono items used by it.
used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
// Maps every mono item to the mono items that use it.
user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,

View file

@ -92,6 +92,8 @@
//! source-level module, functions from the same module will be available for
//! inlining, even when they are not marked `#[inline]`.
mod autodiff;
use std::cmp;
use std::collections::hash_map::Entry;
use std::fs::{self, File};
@ -251,7 +253,17 @@ where
can_export_generics,
always_export_generics,
);
if visibility == Visibility::Hidden && can_be_internalized {
// We can't differentiate something that got inlined.
let autodiff_active = cfg!(llvm_enzyme)
&& cx
.tcx
.codegen_fn_attrs(mono_item.def_id())
.autodiff_item
.as_ref()
.is_some_and(|ad| ad.is_active());
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
internalization_candidates.insert(mono_item);
}
let size_estimate = mono_item.size_estimate(cx.tcx);
@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
})
.collect();
let autodiff_mono_items: Vec<_> = items
.iter()
.filter_map(|item| match *item {
MonoItem::Fn(ref instance) => Some((item, instance)),
_ => None,
})
.collect();
let autodiff_items =
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
// Output monomorphization stats per def_id
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
if let Err(err) =
@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
}
}
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
MonoItemPartitions {
all_mono_items: tcx.arena.alloc(mono_items),
codegen_units,
autodiff_items,
}
}
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s

View file

@ -0,0 +1,121 @@
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};
use crate::partitioning::UsageMap;
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
}
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
// If rustc compiles the unmodified primal, we know that this copy of the function
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
// (or actually at all), so let's strip lifetimes when computing the layout.
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
let mut new_activities = vec![];
let mut new_positions = vec![];
for (i, ty) in x.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if ty.is_fn_ptr() {
// FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
// since Enzyme itself can handle them.
tcx.dcx().err("function pointers are currently not supported in autodiff");
}
if inner_ty.is_slice() {
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}
}
pub(crate) fn find_autodiff_source_functions<'tcx>(
tcx: TyCtxt<'tcx>,
usage_map: &UsageMap<'tcx>,
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
) -> Vec<AutoDiffItem> {
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
for (item, instance) in autodiff_mono_items {
let target_id = instance.def_id();
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
let Some(target_attrs) = cg_fn_attr else {
continue;
};
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
if target_attrs.is_source() {
trace!("source found: {:?}", target_id);
}
if !target_attrs.apply_autodiff() {
continue;
}
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
let source =
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
MonoItem::Fn(ref instance_s) => {
let source_id = instance_s.def_id();
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
&& ad.is_active()
{
return Some(instance_s);
}
None
}
_ => None,
});
let inst = match source {
Some(source) => source,
None => continue,
};
debug!("source_id: {:?}", inst.def_id());
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
assert!(fn_ty.is_fn());
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
let mut new_target_attrs = target_attrs.clone();
new_target_attrs.input_activity = input_activities;
let itm = new_target_attrs.into_item(symb, target_symbol);
autodiff_items.push(itm);
}
if !autodiff_items.is_empty() {
trace!("AUTODIFF ITEMS EXIST");
for item in &mut *autodiff_items {
trace!("{}", &item);
}
}
autodiff_items
}

View file

@ -189,6 +189,39 @@ pub enum CoverageLevel {
Mcdc,
}
/// The different settings that the `-Z autodiff` flag can have.
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
pub enum AutoDiff {
/// Print TypeAnalysis information
PrintTA,
/// Print ActivityAnalysis Information
PrintAA,
/// Print Performance Warnings from Enzyme
PrintPerf,
/// Combines the three print flags above.
Print,
/// Print the whole module, before running opts.
PrintModBefore,
/// Print the whole module just before we pass it to Enzyme.
/// For Debug purpose, prefer the OPT flag below
PrintModAfterOpts,
/// Print the module after Enzyme differentiated everything.
PrintModAfterEnzyme,
/// Enzyme's loose type debug helper (can cause incorrect gradients)
LooseTypes,
/// More flags
NoModOptAfter,
/// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
/// since we already optimize the whole module after Enzyme is done.
EnableFncOpt,
NoVecUnroll,
RuntimeActivity,
/// Runs Enzyme specific Inlining
Inline,
}
/// Settings for `-Z instrument-xray` flag.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct InstrumentXRay {
@ -2902,7 +2935,7 @@ pub(crate) mod dep_tracking {
};
use super::{
BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
@ -2950,6 +2983,7 @@ pub(crate) mod dep_tracking {
}
impl_dep_tracking_hash_via_hash!(
AutoDiff,
bool,
usize,
NonZero<usize>,

View file

@ -398,6 +398,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
@ -1029,6 +1030,38 @@ pub mod parse {
}
}
pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
let Some(v) = v else {
*slot = vec![];
return true;
};
let mut v: Vec<&str> = v.split(",").collect();
v.sort_unstable();
for &val in v.iter() {
let variant = match val {
"PrintTA" => AutoDiff::PrintTA,
"PrintAA" => AutoDiff::PrintAA,
"PrintPerf" => AutoDiff::PrintPerf,
"Print" => AutoDiff::Print,
"PrintModBefore" => AutoDiff::PrintModBefore,
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
"LooseTypes" => AutoDiff::LooseTypes,
"NoModOptAfter" => AutoDiff::NoModOptAfter,
"EnableFncOpt" => AutoDiff::EnableFncOpt,
"NoVecUnroll" => AutoDiff::NoVecUnroll,
"Inline" => AutoDiff::Inline,
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
return false;
}
};
slot.push(variant);
}
true
}
pub(crate) fn parse_instrument_coverage(
slot: &mut InstrumentCoverage,
v: Option<&str>,
@ -1736,6 +1769,22 @@ options! {
either `loaded` or `not-loaded`."),
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
"make cfg(version) treat the current version as incomplete (default: no)"),
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
"a list of optional autodiff flags to enable
Optional extra settings:
`=PrintTA`
`=PrintAA`
`=PrintPerf`
`=Print`
`=PrintModBefore`
`=PrintModAfterOpts`
`=PrintModAfterEnzyme`
`=LooseTypes`
`=NoModOptAfter`
`=EnableFncOpt`
`=NoVecUnroll`
`=Inline`
Multiple options can be combined with commas."),
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \

View file

@ -502,7 +502,6 @@ symbols! {
augmented_assignments,
auto_traits,
autodiff,
autodiff_fallback,
automatically_derived,
avx,
avx512_target_feature,
@ -568,7 +567,6 @@ symbols! {
cfg_accessible,
cfg_attr,
cfg_attr_multi,
cfg_autodiff_fallback,
cfg_boolean_literals,
cfg_doctest,
cfg_emscripten_wasm_eh,

View file

@ -1049,9 +1049,12 @@ pub fn rustc_cargo(
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
cargo.rustflag("-Zon-broken-pipe=kill");
if builder.config.llvm_enzyme {
cargo.rustflag("-l").rustflag("Enzyme-19");
}
// We temporarily disable linking here as part of some refactoring.
// This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now.
// In a follow-up PR, we will re-enable linking here and load the pass for them.
//if builder.config.llvm_enzyme {
// cargo.rustflag("-l").rustflag("Enzyme-19");
//}
// Building with protected visibility reduces the number of dynamic relocations needed, giving
// us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object

View file

@ -0,0 +1,23 @@
# `autodiff`
The tracking issue for this feature is: [#124509](https://github.com/rust-lang/rust/issues/124509).
------------------------
This feature allows you to differentiate functions using automatic differentiation.
Set the `-Zautodiff=<options>` compiler flag to adjust the behaviour of the autodiff feature.
Multiple options can be separated with a comma. Valid options are:
`PrintTA` - print Type Analysis Information
`PrintAA` - print Activity Analysis Information
`PrintPerf` - print Performance Warnings from Enzyme
`Print` - prints all intermediate transformations
`PrintModBefore` - print the whole module, before running opts
`PrintModAfterOpts` - print the whole module just before we pass it to Enzyme
`PrintModAfterEnzyme` - print the module after Enzyme differentiated everything
`LooseTypes` - Enzyme's loose type debug helper (can cause incorrect gradients)
`Inline` - runs Enzyme specific Inlining
`NoModOptAfter` - do not optimize the module after Enzyme is done
`EnableFncOpt` - tell Enzyme to run LLVM Opts on each function it generated
`NoVecUnroll` - do not unroll vectorized loops
`RuntimeActivity` - allow specifying activity at runtime

@ -1 +1 @@
Subproject commit 2fe5164a2423dd67ef25e2c4fb204fd06362494b
Subproject commit 0e5fa4a3d475f4dece489c9e06b11164f83789f5