upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
This commit is contained in:
parent
ebcf860e73
commit
1f30517d40
27 changed files with 482 additions and 38 deletions
|
@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro
|
|||
|
||||
codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering
|
||||
|
||||
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
|
||||
|
||||
codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty
|
||||
|
||||
codegen_ssa_cgu_not_recorded =
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
|
|||
use std::{fs, io, mem, str, thread};
|
||||
|
||||
use rustc_ast::attr;
|
||||
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
|
||||
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
|
||||
use rustc_data_structures::jobserver::{self, Acquired};
|
||||
use rustc_data_structures::memmap::Mmap;
|
||||
|
@ -40,7 +41,7 @@ use tracing::debug;
|
|||
use super::link::{self, ensure_removed};
|
||||
use super::lto::{self, SerializedModule};
|
||||
use super::symbol_export::symbol_name_for_instance_in_crate;
|
||||
use crate::errors::ErrorCreatingRemarkDir;
|
||||
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
|
||||
use crate::traits::*;
|
||||
use crate::{
|
||||
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
|
||||
|
@ -118,6 +119,7 @@ pub struct ModuleConfig {
|
|||
pub merge_functions: bool,
|
||||
pub emit_lifetime_markers: bool,
|
||||
pub llvm_plugins: Vec<String>,
|
||||
pub autodiff: Vec<config::AutoDiff>,
|
||||
}
|
||||
|
||||
impl ModuleConfig {
|
||||
|
@ -266,6 +268,7 @@ impl ModuleConfig {
|
|||
|
||||
emit_lifetime_markers: sess.emit_lifetime_markers(),
|
||||
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
|
||||
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
|
|||
|
||||
fn generate_lto_work<B: ExtraBackendMethods>(
|
||||
cgcx: &CodegenContext<B>,
|
||||
autodiff: Vec<AutoDiffItem>,
|
||||
needs_fat_lto: Vec<FatLtoInput<B>>,
|
||||
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
|
||||
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
|
||||
|
@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
|
|||
|
||||
if !needs_fat_lto.is_empty() {
|
||||
assert!(needs_thin_lto.is_empty());
|
||||
let module =
|
||||
let mut module =
|
||||
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
|
||||
if cgcx.lto == Lto::Fat {
|
||||
let config = cgcx.config(ModuleKind::Regular);
|
||||
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
|
||||
}
|
||||
// We are adding a single work item, so the cost doesn't matter.
|
||||
vec![(WorkItem::LTO(module), 0)]
|
||||
} else {
|
||||
if !autodiff.is_empty() {
|
||||
let dcx = cgcx.create_dcx();
|
||||
dcx.handle().emit_fatal(AutodiffWithoutLto {});
|
||||
}
|
||||
assert!(needs_fat_lto.is_empty());
|
||||
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
|
||||
.unwrap_or_else(|e| e.raise());
|
||||
|
@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
|
|||
/// Sent from a backend worker thread.
|
||||
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
|
||||
|
||||
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
|
||||
AddAutoDiffItems(Vec<AutoDiffItem>),
|
||||
|
||||
/// The frontend has finished generating something (backend IR or a
|
||||
/// post-LTO artifact) for a codegen unit, and it should be passed to the
|
||||
/// backend. Sent from the main thread.
|
||||
|
@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
|||
|
||||
// This is where we collect codegen units that have gone all the way
|
||||
// through codegen and LLVM.
|
||||
let mut autodiff_items = Vec::new();
|
||||
let mut compiled_modules = vec![];
|
||||
let mut compiled_allocator_module = None;
|
||||
let mut needs_link = Vec::new();
|
||||
|
@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
|||
let needs_thin_lto = mem::take(&mut needs_thin_lto);
|
||||
let import_only_modules = mem::take(&mut lto_import_only_modules);
|
||||
|
||||
for (work, cost) in
|
||||
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
|
||||
{
|
||||
for (work, cost) in generate_lto_work(
|
||||
&cgcx,
|
||||
autodiff_items.clone(),
|
||||
needs_fat_lto,
|
||||
needs_thin_lto,
|
||||
import_only_modules,
|
||||
) {
|
||||
let insertion_index = work_items
|
||||
.binary_search_by_key(&cost, |&(_, cost)| cost)
|
||||
.unwrap_or_else(|e| e);
|
||||
|
@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
|||
main_thread_state = MainThreadState::Idle;
|
||||
}
|
||||
|
||||
Message::AddAutoDiffItems(mut items) => {
|
||||
autodiff_items.append(&mut items);
|
||||
}
|
||||
|
||||
Message::CodegenComplete => {
|
||||
if codegen_state != Aborted {
|
||||
codegen_state = Completed;
|
||||
|
@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
|
|||
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
|
||||
}
|
||||
|
||||
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
|
||||
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
|
||||
}
|
||||
|
||||
pub(crate) fn check_for_errors(&self, sess: &Session) {
|
||||
self.shared_emitter_main.check(sess, false);
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
|
|||
use rustc_middle::middle::exported_symbols::SymbolExportKind;
|
||||
use rustc_middle::middle::{exported_symbols, lang_items};
|
||||
use rustc_middle::mir::BinOp;
|
||||
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
|
||||
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
|
||||
use rustc_middle::query::Providers;
|
||||
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
|
||||
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
||||
|
@ -619,7 +619,9 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
|
|||
|
||||
// Run the monomorphization collector and partition the collected items into
|
||||
// codegen units.
|
||||
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
|
||||
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
|
||||
tcx.collect_and_partition_mono_items(());
|
||||
let autodiff_fncs = autodiff_items.to_vec();
|
||||
|
||||
// Force all codegen_unit queries so they are already either red or green
|
||||
// when compile_codegen_unit accesses them. We are not able to re-execute
|
||||
|
@ -690,6 +692,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
|
|||
);
|
||||
}
|
||||
|
||||
if !autodiff_fncs.is_empty() {
|
||||
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
|
||||
}
|
||||
|
||||
// For better throughput during parallel processing by LLVM, we used to sort
|
||||
// CGUs largest to smallest. This would lead to better thread utilization
|
||||
// by, for example, preventing a large CGU from being processed last and
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use rustc_ast::attr::list_contains_name;
|
||||
use rustc_ast::{MetaItemInner, attr};
|
||||
use rustc_ast::expand::autodiff_attrs::{
|
||||
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
|
||||
};
|
||||
use rustc_ast::{MetaItem, MetaItemInner, attr};
|
||||
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_errors::codes::*;
|
||||
|
@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
|
|||
};
|
||||
use rustc_middle::mir::mono::Linkage;
|
||||
use rustc_middle::query::Providers;
|
||||
use rustc_middle::span_bug;
|
||||
use rustc_middle::ty::{self as ty, TyCtxt};
|
||||
use rustc_session::parse::feature_err;
|
||||
use rustc_session::{Session, lint};
|
||||
|
@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
|
|||
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
|
||||
}
|
||||
|
||||
// If our rustc version supports autodiff/enzyme, then we call our handler
|
||||
// to check for any `#[rustc_autodiff(...)]` attributes.
|
||||
if cfg!(llvm_enzyme) {
|
||||
let ad = autodiff_attrs(tcx, did.into());
|
||||
codegen_fn_attrs.autodiff_item = ad;
|
||||
}
|
||||
|
||||
// When `no_builtins` is applied at the crate level, we should add the
|
||||
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
|
||||
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
|
||||
|
@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
|
||||
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
|
||||
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
|
||||
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
|
||||
/// panic, unless we introduced a bug when parsing the autodiff macro.
|
||||
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
||||
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
|
||||
|
||||
let attrs =
|
||||
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
|
||||
|
||||
// check for exactly one autodiff attribute on placeholder functions.
|
||||
// There should only be one, since we generate a new placeholder per ad macro.
|
||||
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
|
||||
// looks strange e.g. under cargo-expand.
|
||||
let attr = match &attrs[..] {
|
||||
[] => return None,
|
||||
[attr] => attr,
|
||||
// These two attributes are the same and unfortunately duplicated due to a previous bug.
|
||||
[attr, _attr2] => attr,
|
||||
_ => {
|
||||
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
|
||||
//branch above.
|
||||
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
|
||||
}
|
||||
};
|
||||
|
||||
let list = attr.meta_item_list().unwrap_or_default();
|
||||
|
||||
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
|
||||
if list.is_empty() {
|
||||
return Some(AutoDiffAttrs::source());
|
||||
}
|
||||
|
||||
let [mode, input_activities @ .., ret_activity] = &list[..] else {
|
||||
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
|
||||
};
|
||||
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
|
||||
p1.segments.first().unwrap().ident
|
||||
} else {
|
||||
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
|
||||
};
|
||||
|
||||
// parse mode
|
||||
let mode = match mode.as_str() {
|
||||
"Forward" => DiffMode::Forward,
|
||||
"Reverse" => DiffMode::Reverse,
|
||||
"ForwardFirst" => DiffMode::ForwardFirst,
|
||||
"ReverseFirst" => DiffMode::ReverseFirst,
|
||||
_ => {
|
||||
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
|
||||
}
|
||||
};
|
||||
|
||||
// First read the ret symbol from the attribute
|
||||
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
|
||||
p1.segments.first().unwrap().ident
|
||||
} else {
|
||||
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
|
||||
};
|
||||
|
||||
// Then parse it into an actual DiffActivity
|
||||
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
|
||||
span_bug!(ret_symbol.span, "invalid return activity");
|
||||
};
|
||||
|
||||
// Now parse all the intermediate (input) activities
|
||||
let mut arg_activities: Vec<DiffActivity> = vec![];
|
||||
for arg in input_activities {
|
||||
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
|
||||
match p2.segments.first() {
|
||||
Some(x) => x.ident,
|
||||
None => {
|
||||
span_bug!(
|
||||
arg.span(),
|
||||
"rustc_autodiff attribute must contain the input activity"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
|
||||
};
|
||||
|
||||
match DiffActivity::from_str(arg_symbol.as_str()) {
|
||||
Ok(arg_activity) => arg_activities.push(arg_activity),
|
||||
Err(_) => {
|
||||
span_bug!(arg_symbol.span, "invalid input activity");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for &input in &arg_activities {
|
||||
if !valid_input_activity(mode, input) {
|
||||
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
|
||||
}
|
||||
}
|
||||
if !valid_ret_activity(mode, ret_activity) {
|
||||
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
|
||||
}
|
||||
|
||||
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
|
||||
}
|
||||
|
||||
pub(crate) fn provide(providers: &mut Providers) {
|
||||
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
|
||||
}
|
||||
|
|
|
@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
|
|||
pub cgu_name: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(codegen_ssa_autodiff_without_lto)]
|
||||
pub struct AutodiffWithoutLto;
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(codegen_ssa_unknown_reuse_kind)]
|
||||
pub(crate) struct UnknownReuseKind {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue