upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
This commit is contained in:
parent
ebcf860e73
commit
1f30517d40
27 changed files with 482 additions and 38 deletions
|
@ -4234,6 +4234,7 @@ name = "rustc_monomorphize"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rustc_abi",
|
"rustc_abi",
|
||||||
|
"rustc_ast",
|
||||||
"rustc_attr_parsing",
|
"rustc_attr_parsing",
|
||||||
"rustc_data_structures",
|
"rustc_data_structures",
|
||||||
"rustc_errors",
|
"rustc_errors",
|
||||||
|
@ -4243,6 +4244,7 @@ dependencies = [
|
||||||
"rustc_middle",
|
"rustc_middle",
|
||||||
"rustc_session",
|
"rustc_session",
|
||||||
"rustc_span",
|
"rustc_span",
|
||||||
|
"rustc_symbol_mangling",
|
||||||
"rustc_target",
|
"rustc_target",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -79,6 +79,7 @@ pub struct AutoDiffItem {
|
||||||
pub target: String,
|
pub target: String,
|
||||||
pub attrs: AutoDiffAttrs,
|
pub attrs: AutoDiffAttrs,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
pub struct AutoDiffAttrs {
|
pub struct AutoDiffAttrs {
|
||||||
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
|
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
|
||||||
|
@ -231,7 +232,7 @@ impl AutoDiffAttrs {
|
||||||
self.ret_activity == DiffActivity::ActiveOnly
|
self.ret_activity == DiffActivity::ActiveOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn error() -> Self {
|
pub const fn error() -> Self {
|
||||||
AutoDiffAttrs {
|
AutoDiffAttrs {
|
||||||
mode: DiffMode::Error,
|
mode: DiffMode::Error,
|
||||||
ret_activity: DiffActivity::None,
|
ret_activity: DiffActivity::None,
|
||||||
|
|
|
@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>(
|
||||||
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
|
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
|
||||||
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
|
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
|
||||||
let name = llvm::get_value_name(outer_fn);
|
let name = llvm::get_value_name(outer_fn);
|
||||||
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
|
let outer_fn_name = std::str::from_utf8(name).unwrap();
|
||||||
ad_name.push_str(outer_fn_name.to_string().as_str());
|
ad_name.push_str(outer_fn_name);
|
||||||
|
|
||||||
// Let us assume the user wrote the following function square:
|
// Let us assume the user wrote the following function square:
|
||||||
//
|
//
|
||||||
|
@ -255,14 +255,14 @@ fn generate_enzyme_call<'ll>(
|
||||||
// have no debug info to copy, which would then be ok.
|
// have no debug info to copy, which would then be ok.
|
||||||
trace!("no dbg info");
|
trace!("no dbg info");
|
||||||
}
|
}
|
||||||
// Now that we copied the metadata, get rid of dummy code.
|
|
||||||
llvm::LLVMRustEraseInstBefore(entry, last_inst);
|
|
||||||
llvm::LLVMRustEraseInstFromParent(last_inst);
|
|
||||||
|
|
||||||
if cx.val_ty(outer_fn) != cx.type_void() {
|
// Now that we copied the metadata, get rid of dummy code.
|
||||||
builder.ret(call);
|
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
|
||||||
} else {
|
|
||||||
|
if cx.val_ty(call) == cx.type_void() {
|
||||||
builder.ret_void();
|
builder.ret_void();
|
||||||
|
} else {
|
||||||
|
builder.ret(call);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's crash in case that we messed something up above and generated invalid IR.
|
// Let's crash in case that we messed something up above and generated invalid IR.
|
||||||
|
|
|
@ -298,7 +298,7 @@ struct UsageSets<'tcx> {
|
||||||
/// Prepare sets of definitions that are relevant to deciding whether something
|
/// Prepare sets of definitions that are relevant to deciding whether something
|
||||||
/// is an "unused function" for coverage purposes.
|
/// is an "unused function" for coverage purposes.
|
||||||
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
|
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
|
||||||
let MonoItemPartitions { all_mono_items, codegen_units } =
|
let MonoItemPartitions { all_mono_items, codegen_units, .. } =
|
||||||
tcx.collect_and_partition_mono_items(());
|
tcx.collect_and_partition_mono_items(());
|
||||||
|
|
||||||
// Obtain a MIR body for each function participating in codegen, via an
|
// Obtain a MIR body for each function participating in codegen, via an
|
||||||
|
|
|
@ -7,11 +7,13 @@ use crate::llvm::Bool;
|
||||||
extern "C" {
|
extern "C" {
|
||||||
// Enzyme
|
// Enzyme
|
||||||
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
|
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
|
||||||
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
|
pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
|
||||||
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
|
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
|
||||||
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
|
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
|
||||||
pub fn LLVMRustEraseInstFromParent(V: &Value);
|
pub fn LLVMRustEraseInstFromParent(V: &Value);
|
||||||
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
|
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
|
||||||
|
pub fn LLVMDumpModule(M: &Module);
|
||||||
|
pub fn LLVMDumpValue(V: &Value);
|
||||||
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
|
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
|
||||||
|
|
||||||
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
|
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
|
||||||
|
|
|
@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro
|
||||||
|
|
||||||
codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering
|
codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering
|
||||||
|
|
||||||
|
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
|
||||||
|
|
||||||
codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty
|
codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty
|
||||||
|
|
||||||
codegen_ssa_cgu_not_recorded =
|
codegen_ssa_cgu_not_recorded =
|
||||||
|
|
|
@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
|
||||||
use std::{fs, io, mem, str, thread};
|
use std::{fs, io, mem, str, thread};
|
||||||
|
|
||||||
use rustc_ast::attr;
|
use rustc_ast::attr;
|
||||||
|
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
|
||||||
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
|
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
|
||||||
use rustc_data_structures::jobserver::{self, Acquired};
|
use rustc_data_structures::jobserver::{self, Acquired};
|
||||||
use rustc_data_structures::memmap::Mmap;
|
use rustc_data_structures::memmap::Mmap;
|
||||||
|
@ -40,7 +41,7 @@ use tracing::debug;
|
||||||
use super::link::{self, ensure_removed};
|
use super::link::{self, ensure_removed};
|
||||||
use super::lto::{self, SerializedModule};
|
use super::lto::{self, SerializedModule};
|
||||||
use super::symbol_export::symbol_name_for_instance_in_crate;
|
use super::symbol_export::symbol_name_for_instance_in_crate;
|
||||||
use crate::errors::ErrorCreatingRemarkDir;
|
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
|
||||||
use crate::traits::*;
|
use crate::traits::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
|
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
|
||||||
|
@ -118,6 +119,7 @@ pub struct ModuleConfig {
|
||||||
pub merge_functions: bool,
|
pub merge_functions: bool,
|
||||||
pub emit_lifetime_markers: bool,
|
pub emit_lifetime_markers: bool,
|
||||||
pub llvm_plugins: Vec<String>,
|
pub llvm_plugins: Vec<String>,
|
||||||
|
pub autodiff: Vec<config::AutoDiff>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModuleConfig {
|
impl ModuleConfig {
|
||||||
|
@ -266,6 +268,7 @@ impl ModuleConfig {
|
||||||
|
|
||||||
emit_lifetime_markers: sess.emit_lifetime_markers(),
|
emit_lifetime_markers: sess.emit_lifetime_markers(),
|
||||||
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
|
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
|
||||||
|
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
|
||||||
|
|
||||||
fn generate_lto_work<B: ExtraBackendMethods>(
|
fn generate_lto_work<B: ExtraBackendMethods>(
|
||||||
cgcx: &CodegenContext<B>,
|
cgcx: &CodegenContext<B>,
|
||||||
|
autodiff: Vec<AutoDiffItem>,
|
||||||
needs_fat_lto: Vec<FatLtoInput<B>>,
|
needs_fat_lto: Vec<FatLtoInput<B>>,
|
||||||
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
|
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
|
||||||
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
|
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
|
||||||
|
@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
|
||||||
|
|
||||||
if !needs_fat_lto.is_empty() {
|
if !needs_fat_lto.is_empty() {
|
||||||
assert!(needs_thin_lto.is_empty());
|
assert!(needs_thin_lto.is_empty());
|
||||||
let module =
|
let mut module =
|
||||||
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
|
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
|
||||||
|
if cgcx.lto == Lto::Fat {
|
||||||
|
let config = cgcx.config(ModuleKind::Regular);
|
||||||
|
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
|
||||||
|
}
|
||||||
// We are adding a single work item, so the cost doesn't matter.
|
// We are adding a single work item, so the cost doesn't matter.
|
||||||
vec![(WorkItem::LTO(module), 0)]
|
vec![(WorkItem::LTO(module), 0)]
|
||||||
} else {
|
} else {
|
||||||
|
if !autodiff.is_empty() {
|
||||||
|
let dcx = cgcx.create_dcx();
|
||||||
|
dcx.handle().emit_fatal(AutodiffWithoutLto {});
|
||||||
|
}
|
||||||
assert!(needs_fat_lto.is_empty());
|
assert!(needs_fat_lto.is_empty());
|
||||||
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
|
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
|
||||||
.unwrap_or_else(|e| e.raise());
|
.unwrap_or_else(|e| e.raise());
|
||||||
|
@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
|
||||||
/// Sent from a backend worker thread.
|
/// Sent from a backend worker thread.
|
||||||
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
|
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
|
||||||
|
|
||||||
|
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
|
||||||
|
AddAutoDiffItems(Vec<AutoDiffItem>),
|
||||||
|
|
||||||
/// The frontend has finished generating something (backend IR or a
|
/// The frontend has finished generating something (backend IR or a
|
||||||
/// post-LTO artifact) for a codegen unit, and it should be passed to the
|
/// post-LTO artifact) for a codegen unit, and it should be passed to the
|
||||||
/// backend. Sent from the main thread.
|
/// backend. Sent from the main thread.
|
||||||
|
@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||||
|
|
||||||
// This is where we collect codegen units that have gone all the way
|
// This is where we collect codegen units that have gone all the way
|
||||||
// through codegen and LLVM.
|
// through codegen and LLVM.
|
||||||
|
let mut autodiff_items = Vec::new();
|
||||||
let mut compiled_modules = vec![];
|
let mut compiled_modules = vec![];
|
||||||
let mut compiled_allocator_module = None;
|
let mut compiled_allocator_module = None;
|
||||||
let mut needs_link = Vec::new();
|
let mut needs_link = Vec::new();
|
||||||
|
@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||||
let needs_thin_lto = mem::take(&mut needs_thin_lto);
|
let needs_thin_lto = mem::take(&mut needs_thin_lto);
|
||||||
let import_only_modules = mem::take(&mut lto_import_only_modules);
|
let import_only_modules = mem::take(&mut lto_import_only_modules);
|
||||||
|
|
||||||
for (work, cost) in
|
for (work, cost) in generate_lto_work(
|
||||||
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
|
&cgcx,
|
||||||
{
|
autodiff_items.clone(),
|
||||||
|
needs_fat_lto,
|
||||||
|
needs_thin_lto,
|
||||||
|
import_only_modules,
|
||||||
|
) {
|
||||||
let insertion_index = work_items
|
let insertion_index = work_items
|
||||||
.binary_search_by_key(&cost, |&(_, cost)| cost)
|
.binary_search_by_key(&cost, |&(_, cost)| cost)
|
||||||
.unwrap_or_else(|e| e);
|
.unwrap_or_else(|e| e);
|
||||||
|
@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
|
||||||
main_thread_state = MainThreadState::Idle;
|
main_thread_state = MainThreadState::Idle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Message::AddAutoDiffItems(mut items) => {
|
||||||
|
autodiff_items.append(&mut items);
|
||||||
|
}
|
||||||
|
|
||||||
Message::CodegenComplete => {
|
Message::CodegenComplete => {
|
||||||
if codegen_state != Aborted {
|
if codegen_state != Aborted {
|
||||||
codegen_state = Completed;
|
codegen_state = Completed;
|
||||||
|
@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
|
||||||
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
|
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
|
||||||
|
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn check_for_errors(&self, sess: &Session) {
|
pub(crate) fn check_for_errors(&self, sess: &Session) {
|
||||||
self.shared_emitter_main.check(sess, false);
|
self.shared_emitter_main.check(sess, false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
|
||||||
use rustc_middle::middle::exported_symbols::SymbolExportKind;
|
use rustc_middle::middle::exported_symbols::SymbolExportKind;
|
||||||
use rustc_middle::middle::{exported_symbols, lang_items};
|
use rustc_middle::middle::{exported_symbols, lang_items};
|
||||||
use rustc_middle::mir::BinOp;
|
use rustc_middle::mir::BinOp;
|
||||||
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
|
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
|
||||||
use rustc_middle::query::Providers;
|
use rustc_middle::query::Providers;
|
||||||
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
|
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
|
||||||
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
||||||
|
@ -619,7 +619,9 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
|
||||||
|
|
||||||
// Run the monomorphization collector and partition the collected items into
|
// Run the monomorphization collector and partition the collected items into
|
||||||
// codegen units.
|
// codegen units.
|
||||||
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
|
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
|
||||||
|
tcx.collect_and_partition_mono_items(());
|
||||||
|
let autodiff_fncs = autodiff_items.to_vec();
|
||||||
|
|
||||||
// Force all codegen_unit queries so they are already either red or green
|
// Force all codegen_unit queries so they are already either red or green
|
||||||
// when compile_codegen_unit accesses them. We are not able to re-execute
|
// when compile_codegen_unit accesses them. We are not able to re-execute
|
||||||
|
@ -690,6 +692,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !autodiff_fncs.is_empty() {
|
||||||
|
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
|
||||||
|
}
|
||||||
|
|
||||||
// For better throughput during parallel processing by LLVM, we used to sort
|
// For better throughput during parallel processing by LLVM, we used to sort
|
||||||
// CGUs largest to smallest. This would lead to better thread utilization
|
// CGUs largest to smallest. This would lead to better thread utilization
|
||||||
// by, for example, preventing a large CGU from being processed last and
|
// by, for example, preventing a large CGU from being processed last and
|
||||||
|
|
|
@ -1,5 +1,10 @@
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
use rustc_ast::attr::list_contains_name;
|
use rustc_ast::attr::list_contains_name;
|
||||||
use rustc_ast::{MetaItemInner, attr};
|
use rustc_ast::expand::autodiff_attrs::{
|
||||||
|
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
|
||||||
|
};
|
||||||
|
use rustc_ast::{MetaItem, MetaItemInner, attr};
|
||||||
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||||
use rustc_data_structures::fx::FxHashMap;
|
use rustc_data_structures::fx::FxHashMap;
|
||||||
use rustc_errors::codes::*;
|
use rustc_errors::codes::*;
|
||||||
|
@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
|
||||||
};
|
};
|
||||||
use rustc_middle::mir::mono::Linkage;
|
use rustc_middle::mir::mono::Linkage;
|
||||||
use rustc_middle::query::Providers;
|
use rustc_middle::query::Providers;
|
||||||
|
use rustc_middle::span_bug;
|
||||||
use rustc_middle::ty::{self as ty, TyCtxt};
|
use rustc_middle::ty::{self as ty, TyCtxt};
|
||||||
use rustc_session::parse::feature_err;
|
use rustc_session::parse::feature_err;
|
||||||
use rustc_session::{Session, lint};
|
use rustc_session::{Session, lint};
|
||||||
|
@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
|
||||||
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
|
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If our rustc version supports autodiff/enzyme, then we call our handler
|
||||||
|
// to check for any `#[rustc_autodiff(...)]` attributes.
|
||||||
|
if cfg!(llvm_enzyme) {
|
||||||
|
let ad = autodiff_attrs(tcx, did.into());
|
||||||
|
codegen_fn_attrs.autodiff_item = ad;
|
||||||
|
}
|
||||||
|
|
||||||
// When `no_builtins` is applied at the crate level, we should add the
|
// When `no_builtins` is applied at the crate level, we should add the
|
||||||
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
|
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
|
||||||
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
|
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
|
||||||
|
@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
|
||||||
|
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
|
||||||
|
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
|
||||||
|
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
|
||||||
|
/// panic, unless we introduced a bug when parsing the autodiff macro.
|
||||||
|
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
|
||||||
|
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
|
||||||
|
|
||||||
|
let attrs =
|
||||||
|
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// check for exactly one autodiff attribute on placeholder functions.
|
||||||
|
// There should only be one, since we generate a new placeholder per ad macro.
|
||||||
|
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
|
||||||
|
// looks strange e.g. under cargo-expand.
|
||||||
|
let attr = match &attrs[..] {
|
||||||
|
[] => return None,
|
||||||
|
[attr] => attr,
|
||||||
|
// These two attributes are the same and unfortunately duplicated due to a previous bug.
|
||||||
|
[attr, _attr2] => attr,
|
||||||
|
_ => {
|
||||||
|
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
|
||||||
|
//branch above.
|
||||||
|
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let list = attr.meta_item_list().unwrap_or_default();
|
||||||
|
|
||||||
|
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
|
||||||
|
if list.is_empty() {
|
||||||
|
return Some(AutoDiffAttrs::source());
|
||||||
|
}
|
||||||
|
|
||||||
|
let [mode, input_activities @ .., ret_activity] = &list[..] else {
|
||||||
|
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
|
||||||
|
};
|
||||||
|
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
|
||||||
|
p1.segments.first().unwrap().ident
|
||||||
|
} else {
|
||||||
|
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
|
||||||
|
};
|
||||||
|
|
||||||
|
// parse mode
|
||||||
|
let mode = match mode.as_str() {
|
||||||
|
"Forward" => DiffMode::Forward,
|
||||||
|
"Reverse" => DiffMode::Reverse,
|
||||||
|
"ForwardFirst" => DiffMode::ForwardFirst,
|
||||||
|
"ReverseFirst" => DiffMode::ReverseFirst,
|
||||||
|
_ => {
|
||||||
|
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// First read the ret symbol from the attribute
|
||||||
|
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
|
||||||
|
p1.segments.first().unwrap().ident
|
||||||
|
} else {
|
||||||
|
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Then parse it into an actual DiffActivity
|
||||||
|
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
|
||||||
|
span_bug!(ret_symbol.span, "invalid return activity");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Now parse all the intermediate (input) activities
|
||||||
|
let mut arg_activities: Vec<DiffActivity> = vec![];
|
||||||
|
for arg in input_activities {
|
||||||
|
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
|
||||||
|
match p2.segments.first() {
|
||||||
|
Some(x) => x.ident,
|
||||||
|
None => {
|
||||||
|
span_bug!(
|
||||||
|
arg.span(),
|
||||||
|
"rustc_autodiff attribute must contain the input activity"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
|
||||||
|
};
|
||||||
|
|
||||||
|
match DiffActivity::from_str(arg_symbol.as_str()) {
|
||||||
|
Ok(arg_activity) => arg_activities.push(arg_activity),
|
||||||
|
Err(_) => {
|
||||||
|
span_bug!(arg_symbol.span, "invalid input activity");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for &input in &arg_activities {
|
||||||
|
if !valid_input_activity(mode, input) {
|
||||||
|
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valid_ret_activity(mode, ret_activity) {
|
||||||
|
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn provide(providers: &mut Providers) {
|
pub(crate) fn provide(providers: &mut Providers) {
|
||||||
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
|
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
|
||||||
pub cgu_name: &'a str,
|
pub cgu_name: &'a str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(codegen_ssa_autodiff_without_lto)]
|
||||||
|
pub struct AutodiffWithoutLto;
|
||||||
|
|
||||||
#[derive(Diagnostic)]
|
#[derive(Diagnostic)]
|
||||||
#[diag(codegen_ssa_unknown_reuse_kind)]
|
#[diag(codegen_ssa_unknown_reuse_kind)]
|
||||||
pub(crate) struct UnknownReuseKind {
|
pub(crate) struct UnknownReuseKind {
|
||||||
|
|
|
@ -8,12 +8,12 @@ use rustc_data_structures::profiling::TimePassesFormat;
|
||||||
use rustc_errors::emitter::HumanReadableErrorType;
|
use rustc_errors::emitter::HumanReadableErrorType;
|
||||||
use rustc_errors::{ColorConfig, registry};
|
use rustc_errors::{ColorConfig, registry};
|
||||||
use rustc_session::config::{
|
use rustc_session::config::{
|
||||||
BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel, CoverageOptions,
|
AutoDiff, BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel,
|
||||||
DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation, Externs,
|
CoverageOptions, DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation,
|
||||||
FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage, InstrumentXRay,
|
Externs, FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage,
|
||||||
LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans, NextSolverConfig,
|
InstrumentXRay, LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans,
|
||||||
OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet, Passes,
|
NextSolverConfig, OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet,
|
||||||
PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath,
|
Passes, PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath,
|
||||||
SymbolManglingVersion, WasiExecModel, build_configuration, build_session_options,
|
SymbolManglingVersion, WasiExecModel, build_configuration, build_session_options,
|
||||||
rustc_optgroups,
|
rustc_optgroups,
|
||||||
};
|
};
|
||||||
|
@ -760,6 +760,7 @@ fn test_unstable_options_tracking_hash() {
|
||||||
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
||||||
tracked!(always_encode_mir, true);
|
tracked!(always_encode_mir, true);
|
||||||
tracked!(assume_incomplete_release, true);
|
tracked!(assume_incomplete_release, true);
|
||||||
|
tracked!(autodiff, vec![AutoDiff::Print]);
|
||||||
tracked!(binary_dep_depinfo, true);
|
tracked!(binary_dep_depinfo, true);
|
||||||
tracked!(box_noalias, false);
|
tracked!(box_noalias, false);
|
||||||
tracked!(
|
tracked!(
|
||||||
|
|
|
@ -955,7 +955,8 @@ extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
|
extern "C" void LLVMRustEraseInstUntilInclusive(LLVMBasicBlockRef bb,
|
||||||
|
LLVMValueRef I) {
|
||||||
auto &BB = *unwrap(bb);
|
auto &BB = *unwrap(bb);
|
||||||
auto &Inst = *unwrap<Instruction>(I);
|
auto &Inst = *unwrap<Instruction>(I);
|
||||||
auto It = BB.begin();
|
auto It = BB.begin();
|
||||||
|
@ -963,8 +964,6 @@ extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
|
||||||
++It;
|
++It;
|
||||||
// Make sure we found the Instruction.
|
// Make sure we found the Instruction.
|
||||||
assert(It != BB.end());
|
assert(It != BB.end());
|
||||||
// We don't want to erase the instruction itself.
|
|
||||||
It--;
|
|
||||||
// Delete in rev order to ensure no dangling references.
|
// Delete in rev order to ensure no dangling references.
|
||||||
while (It != BB.begin()) {
|
while (It != BB.begin()) {
|
||||||
auto Prev = std::prev(It);
|
auto Prev = std::prev(It);
|
||||||
|
|
|
@ -32,6 +32,8 @@ middle_assert_shl_overflow =
|
||||||
middle_assert_shr_overflow =
|
middle_assert_shr_overflow =
|
||||||
attempt to shift right by `{$val}`, which would overflow
|
attempt to shift right by `{$val}`, which would overflow
|
||||||
|
|
||||||
|
middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe
|
||||||
|
|
||||||
middle_bounds_check =
|
middle_bounds_check =
|
||||||
index out of bounds: the length is {$len} but the index is {$index}
|
index out of bounds: the length is {$len} but the index is {$index}
|
||||||
|
|
||||||
|
@ -107,6 +109,8 @@ middle_type_length_limit = reached the type-length limit while instantiating `{$
|
||||||
middle_unknown_layout =
|
middle_unknown_layout =
|
||||||
the type `{$ty}` has an unknown layout
|
the type `{$ty}` has an unknown layout
|
||||||
|
|
||||||
|
middle_unsupported_union = we don't support unions yet: '{$ty_name}'
|
||||||
|
|
||||||
middle_values_too_big =
|
middle_values_too_big =
|
||||||
values of the type `{$ty}` are too big for the target architecture
|
values of the type `{$ty}` are too big for the target architecture
|
||||||
|
|
||||||
|
|
|
@ -87,6 +87,7 @@ macro_rules! arena_types {
|
||||||
[] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>,
|
[] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>,
|
||||||
[decode] attribute: rustc_hir::Attribute,
|
[decode] attribute: rustc_hir::Attribute,
|
||||||
[] name_set: rustc_data_structures::unord::UnordSet<rustc_span::Symbol>,
|
[] name_set: rustc_data_structures::unord::UnordSet<rustc_span::Symbol>,
|
||||||
|
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
|
||||||
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
|
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
|
||||||
[] pats: rustc_middle::ty::PatternKind<'tcx>,
|
[] pats: rustc_middle::ty::PatternKind<'tcx>,
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,20 @@ pub struct OpaqueHiddenTypeMismatch<'tcx> {
|
||||||
pub sub: TypeMismatchReason,
|
pub sub: TypeMismatchReason,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(middle_unsupported_union)]
|
||||||
|
pub struct UnsupportedUnion {
|
||||||
|
pub ty_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(middle_autodiff_unsafe_inner_const_ref)]
|
||||||
|
pub struct AutodiffUnsafeInnerConstRef {
|
||||||
|
#[primary_span]
|
||||||
|
pub span: Span,
|
||||||
|
pub ty: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Subdiagnostic)]
|
#[derive(Subdiagnostic)]
|
||||||
pub enum TypeMismatchReason {
|
pub enum TypeMismatchReason {
|
||||||
#[label(middle_conflict_types)]
|
#[label(middle_conflict_types)]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use rustc_abi::Align;
|
use rustc_abi::Align;
|
||||||
|
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
|
||||||
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||||
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
|
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
|
||||||
use rustc_span::Symbol;
|
use rustc_span::Symbol;
|
||||||
|
@ -52,6 +53,8 @@ pub struct CodegenFnAttrs {
|
||||||
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
|
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
|
||||||
/// the function entry.
|
/// the function entry.
|
||||||
pub patchable_function_entry: Option<PatchableFunctionEntry>,
|
pub patchable_function_entry: Option<PatchableFunctionEntry>,
|
||||||
|
/// For the `#[autodiff]` macros.
|
||||||
|
pub autodiff_item: Option<AutoDiffAttrs>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
|
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
|
||||||
|
@ -160,6 +163,7 @@ impl CodegenFnAttrs {
|
||||||
instruction_set: None,
|
instruction_set: None,
|
||||||
alignment: None,
|
alignment: None,
|
||||||
patchable_function_entry: None,
|
patchable_function_entry: None,
|
||||||
|
autodiff_item: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
|
||||||
use rustc_attr_parsing::InlineAttr;
|
use rustc_attr_parsing::InlineAttr;
|
||||||
use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN};
|
use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN};
|
||||||
use rustc_data_structures::fingerprint::Fingerprint;
|
use rustc_data_structures::fingerprint::Fingerprint;
|
||||||
|
@ -251,6 +252,7 @@ impl ToStableHashKey<StableHashingContext<'_>> for MonoItem<'_> {
|
||||||
pub struct MonoItemPartitions<'tcx> {
|
pub struct MonoItemPartitions<'tcx> {
|
||||||
pub codegen_units: &'tcx [CodegenUnit<'tcx>],
|
pub codegen_units: &'tcx [CodegenUnit<'tcx>],
|
||||||
pub all_mono_items: &'tcx DefIdSet,
|
pub all_mono_items: &'tcx DefIdSet,
|
||||||
|
pub autodiff_items: &'tcx [AutoDiffItem],
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, HashStable)]
|
#[derive(Debug, HashStable)]
|
||||||
|
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# tidy-alphabetical-start
|
# tidy-alphabetical-start
|
||||||
rustc_abi = { path = "../rustc_abi" }
|
rustc_abi = { path = "../rustc_abi" }
|
||||||
|
rustc_ast = { path = "../rustc_ast" }
|
||||||
rustc_attr_parsing = { path = "../rustc_attr_parsing" }
|
rustc_attr_parsing = { path = "../rustc_attr_parsing" }
|
||||||
rustc_data_structures = { path = "../rustc_data_structures" }
|
rustc_data_structures = { path = "../rustc_data_structures" }
|
||||||
rustc_errors = { path = "../rustc_errors" }
|
rustc_errors = { path = "../rustc_errors" }
|
||||||
|
@ -15,6 +16,7 @@ rustc_macros = { path = "../rustc_macros" }
|
||||||
rustc_middle = { path = "../rustc_middle" }
|
rustc_middle = { path = "../rustc_middle" }
|
||||||
rustc_session = { path = "../rustc_session" }
|
rustc_session = { path = "../rustc_session" }
|
||||||
rustc_span = { path = "../rustc_span" }
|
rustc_span = { path = "../rustc_span" }
|
||||||
|
rustc_symbol_mangling = { path = "../rustc_symbol_mangling" }
|
||||||
rustc_target = { path = "../rustc_target" }
|
rustc_target = { path = "../rustc_target" }
|
||||||
serde = "1"
|
serde = "1"
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
|
|
@ -257,7 +257,7 @@ struct SharedState<'tcx> {
|
||||||
|
|
||||||
pub(crate) struct UsageMap<'tcx> {
|
pub(crate) struct UsageMap<'tcx> {
|
||||||
// Maps every mono item to the mono items used by it.
|
// Maps every mono item to the mono items used by it.
|
||||||
used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||||
|
|
||||||
// Maps every mono item to the mono items that use it.
|
// Maps every mono item to the mono items that use it.
|
||||||
user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||||
|
|
|
@ -92,6 +92,8 @@
|
||||||
//! source-level module, functions from the same module will be available for
|
//! source-level module, functions from the same module will be available for
|
||||||
//! inlining, even when they are not marked `#[inline]`.
|
//! inlining, even when they are not marked `#[inline]`.
|
||||||
|
|
||||||
|
mod autodiff;
|
||||||
|
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
use std::fs::{self, File};
|
use std::fs::{self, File};
|
||||||
|
@ -251,7 +253,17 @@ where
|
||||||
can_export_generics,
|
can_export_generics,
|
||||||
always_export_generics,
|
always_export_generics,
|
||||||
);
|
);
|
||||||
if visibility == Visibility::Hidden && can_be_internalized {
|
|
||||||
|
// We can't differentiate something that got inlined.
|
||||||
|
let autodiff_active = cfg!(llvm_enzyme)
|
||||||
|
&& cx
|
||||||
|
.tcx
|
||||||
|
.codegen_fn_attrs(mono_item.def_id())
|
||||||
|
.autodiff_item
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|ad| ad.is_active());
|
||||||
|
|
||||||
|
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
|
||||||
internalization_candidates.insert(mono_item);
|
internalization_candidates.insert(mono_item);
|
||||||
}
|
}
|
||||||
let size_estimate = mono_item.size_estimate(cx.tcx);
|
let size_estimate = mono_item.size_estimate(cx.tcx);
|
||||||
|
@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let autodiff_mono_items: Vec<_> = items
|
||||||
|
.iter()
|
||||||
|
.filter_map(|item| match *item {
|
||||||
|
MonoItem::Fn(ref instance) => Some((item, instance)),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let autodiff_items =
|
||||||
|
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
|
||||||
|
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
|
||||||
|
|
||||||
// Output monomorphization stats per def_id
|
// Output monomorphization stats per def_id
|
||||||
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
|
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
|
||||||
if let Err(err) =
|
if let Err(err) =
|
||||||
|
@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
|
MonoItemPartitions {
|
||||||
|
all_mono_items: tcx.arena.alloc(mono_items),
|
||||||
|
codegen_units,
|
||||||
|
autodiff_items,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
|
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
|
||||||
|
|
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
|
||||||
|
use rustc_hir::def_id::LOCAL_CRATE;
|
||||||
|
use rustc_middle::bug;
|
||||||
|
use rustc_middle::mir::mono::MonoItem;
|
||||||
|
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
||||||
|
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
|
||||||
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
|
use crate::partitioning::UsageMap;
|
||||||
|
|
||||||
|
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
|
||||||
|
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||||
|
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||||
|
}
|
||||||
|
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
|
||||||
|
|
||||||
|
// If rustc compiles the unmodified primal, we know that this copy of the function
|
||||||
|
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
|
||||||
|
// (or actually at all), so let's strip lifetimes when computing the layout.
|
||||||
|
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
|
||||||
|
let mut new_activities = vec![];
|
||||||
|
let mut new_positions = vec![];
|
||||||
|
for (i, ty) in x.inputs().iter().enumerate() {
|
||||||
|
if let Some(inner_ty) = ty.builtin_deref(true) {
|
||||||
|
if ty.is_fn_ptr() {
|
||||||
|
// FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
|
||||||
|
// since Enzyme itself can handle them.
|
||||||
|
tcx.dcx().err("function pointers are currently not supported in autodiff");
|
||||||
|
}
|
||||||
|
if inner_ty.is_slice() {
|
||||||
|
// We know that the length will be passed as extra arg.
|
||||||
|
if !da.is_empty() {
|
||||||
|
// We are looking at a slice. The length of that slice will become an
|
||||||
|
// extra integer on llvm level. Integers are always const.
|
||||||
|
// However, if the slice get's duplicated, we want to know to later check the
|
||||||
|
// size. So we mark the new size argument as FakeActivitySize.
|
||||||
|
let activity = match da[i] {
|
||||||
|
DiffActivity::DualOnly
|
||||||
|
| DiffActivity::Dual
|
||||||
|
| DiffActivity::DuplicatedOnly
|
||||||
|
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
|
||||||
|
DiffActivity::Const => DiffActivity::Const,
|
||||||
|
_ => bug!("unexpected activity for ptr/ref"),
|
||||||
|
};
|
||||||
|
new_activities.push(activity);
|
||||||
|
new_positions.push(i + 1);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// now add the extra activities coming from slices
|
||||||
|
// Reverse order to not invalidate the indices
|
||||||
|
for _ in 0..new_activities.len() {
|
||||||
|
let pos = new_positions.pop().unwrap();
|
||||||
|
let activity = new_activities.pop().unwrap();
|
||||||
|
da.insert(pos, activity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn find_autodiff_source_functions<'tcx>(
|
||||||
|
tcx: TyCtxt<'tcx>,
|
||||||
|
usage_map: &UsageMap<'tcx>,
|
||||||
|
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
|
||||||
|
) -> Vec<AutoDiffItem> {
|
||||||
|
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
|
||||||
|
for (item, instance) in autodiff_mono_items {
|
||||||
|
let target_id = instance.def_id();
|
||||||
|
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
|
||||||
|
let Some(target_attrs) = cg_fn_attr else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
|
||||||
|
if target_attrs.is_source() {
|
||||||
|
trace!("source found: {:?}", target_id);
|
||||||
|
}
|
||||||
|
if !target_attrs.apply_autodiff() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
|
||||||
|
|
||||||
|
let source =
|
||||||
|
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
|
||||||
|
MonoItem::Fn(ref instance_s) => {
|
||||||
|
let source_id = instance_s.def_id();
|
||||||
|
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
|
||||||
|
&& ad.is_active()
|
||||||
|
{
|
||||||
|
return Some(instance_s);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
});
|
||||||
|
let inst = match source {
|
||||||
|
Some(source) => source,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("source_id: {:?}", inst.def_id());
|
||||||
|
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
|
||||||
|
assert!(fn_ty.is_fn());
|
||||||
|
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
|
||||||
|
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
|
||||||
|
|
||||||
|
let mut new_target_attrs = target_attrs.clone();
|
||||||
|
new_target_attrs.input_activity = input_activities;
|
||||||
|
let itm = new_target_attrs.into_item(symb, target_symbol);
|
||||||
|
autodiff_items.push(itm);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !autodiff_items.is_empty() {
|
||||||
|
trace!("AUTODIFF ITEMS EXIST");
|
||||||
|
for item in &mut *autodiff_items {
|
||||||
|
trace!("{}", &item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
autodiff_items
|
||||||
|
}
|
|
@ -189,6 +189,39 @@ pub enum CoverageLevel {
|
||||||
Mcdc,
|
Mcdc,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The different settings that the `-Z autodiff` flag can have.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AutoDiff {
    /// Print TypeAnalysis information
    PrintTA,
    /// Print ActivityAnalysis Information
    PrintAA,
    /// Print Performance Warnings from Enzyme
    PrintPerf,
    /// Combines the three print flags above.
    Print,
    /// Print the whole module, before running opts.
    PrintModBefore,
    /// Print the whole module just before we pass it to Enzyme.
    /// For Debug purpose, prefer the OPT flag below
    PrintModAfterOpts,
    /// Print the module after Enzyme differentiated everything.
    PrintModAfterEnzyme,

    /// Enzyme's loose type debug helper (can cause incorrect gradients)
    LooseTypes,

    // More flags
    /// Do not optimize the module after Enzyme is done.
    NoModOptAfter,
    /// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
    /// since we already optimize the whole module after Enzyme is done.
    EnableFncOpt,
    /// Do not unroll vectorized loops.
    NoVecUnroll,
    /// Allow specifying activity at runtime.
    RuntimeActivity,
    /// Runs Enzyme specific Inlining
    Inline,
}
|
||||||
|
|
||||||
/// Settings for `-Z instrument-xray` flag.
|
/// Settings for `-Z instrument-xray` flag.
|
||||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
|
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
|
||||||
pub struct InstrumentXRay {
|
pub struct InstrumentXRay {
|
||||||
|
@ -2902,7 +2935,7 @@ pub(crate) mod dep_tracking {
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
|
AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
|
||||||
CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
|
CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
|
||||||
InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
|
InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
|
||||||
LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
|
LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
|
||||||
|
@ -2950,6 +2983,7 @@ pub(crate) mod dep_tracking {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_dep_tracking_hash_via_hash!(
|
impl_dep_tracking_hash_via_hash!(
|
||||||
|
AutoDiff,
|
||||||
bool,
|
bool,
|
||||||
usize,
|
usize,
|
||||||
NonZero<usize>,
|
NonZero<usize>,
|
||||||
|
|
|
@ -398,6 +398,7 @@ mod desc {
|
||||||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||||
pub(crate) const parse_list_with_polarity: &str =
|
pub(crate) const parse_list_with_polarity: &str =
|
||||||
"a comma-separated list of strings, with elements beginning with + or -";
|
"a comma-separated list of strings, with elements beginning with + or -";
|
||||||
|
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
|
||||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||||
pub(crate) const parse_number: &str = "a number";
|
pub(crate) const parse_number: &str = "a number";
|
||||||
|
@ -1029,6 +1030,38 @@ pub mod parse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
|
||||||
|
let Some(v) = v else {
|
||||||
|
*slot = vec![];
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
let mut v: Vec<&str> = v.split(",").collect();
|
||||||
|
v.sort_unstable();
|
||||||
|
for &val in v.iter() {
|
||||||
|
let variant = match val {
|
||||||
|
"PrintTA" => AutoDiff::PrintTA,
|
||||||
|
"PrintAA" => AutoDiff::PrintAA,
|
||||||
|
"PrintPerf" => AutoDiff::PrintPerf,
|
||||||
|
"Print" => AutoDiff::Print,
|
||||||
|
"PrintModBefore" => AutoDiff::PrintModBefore,
|
||||||
|
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
|
||||||
|
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
|
||||||
|
"LooseTypes" => AutoDiff::LooseTypes,
|
||||||
|
"NoModOptAfter" => AutoDiff::NoModOptAfter,
|
||||||
|
"EnableFncOpt" => AutoDiff::EnableFncOpt,
|
||||||
|
"NoVecUnroll" => AutoDiff::NoVecUnroll,
|
||||||
|
"Inline" => AutoDiff::Inline,
|
||||||
|
_ => {
|
||||||
|
// FIXME(ZuseZ4): print an error saying which value is not recognized
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
slot.push(variant);
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn parse_instrument_coverage(
|
pub(crate) fn parse_instrument_coverage(
|
||||||
slot: &mut InstrumentCoverage,
|
slot: &mut InstrumentCoverage,
|
||||||
v: Option<&str>,
|
v: Option<&str>,
|
||||||
|
@ -1736,6 +1769,22 @@ options! {
|
||||||
either `loaded` or `not-loaded`."),
|
either `loaded` or `not-loaded`."),
|
||||||
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
|
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
|
||||||
"make cfg(version) treat the current version as incomplete (default: no)"),
|
"make cfg(version) treat the current version as incomplete (default: no)"),
|
||||||
|
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
|
||||||
|
"a list of optional autodiff flags to enable
|
||||||
|
Optional extra settings:
|
||||||
|
`=PrintTA`
|
||||||
|
`=PrintAA`
|
||||||
|
`=PrintPerf`
|
||||||
|
`=Print`
|
||||||
|
`=PrintModBefore`
|
||||||
|
`=PrintModAfterOpts`
|
||||||
|
`=PrintModAfterEnzyme`
|
||||||
|
`=LooseTypes`
|
||||||
|
`=NoModOptAfter`
|
||||||
|
`=EnableFncOpt`
|
||||||
|
`=NoVecUnroll`
|
||||||
|
`=Inline`
|
||||||
|
Multiple options can be combined with commas."),
|
||||||
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
|
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
|
||||||
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
|
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
|
||||||
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \
|
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \
|
||||||
|
|
|
@ -502,7 +502,6 @@ symbols! {
|
||||||
augmented_assignments,
|
augmented_assignments,
|
||||||
auto_traits,
|
auto_traits,
|
||||||
autodiff,
|
autodiff,
|
||||||
autodiff_fallback,
|
|
||||||
automatically_derived,
|
automatically_derived,
|
||||||
avx,
|
avx,
|
||||||
avx512_target_feature,
|
avx512_target_feature,
|
||||||
|
@ -568,7 +567,6 @@ symbols! {
|
||||||
cfg_accessible,
|
cfg_accessible,
|
||||||
cfg_attr,
|
cfg_attr,
|
||||||
cfg_attr_multi,
|
cfg_attr_multi,
|
||||||
cfg_autodiff_fallback,
|
|
||||||
cfg_boolean_literals,
|
cfg_boolean_literals,
|
||||||
cfg_doctest,
|
cfg_doctest,
|
||||||
cfg_emscripten_wasm_eh,
|
cfg_emscripten_wasm_eh,
|
||||||
|
|
|
@ -1049,9 +1049,12 @@ pub fn rustc_cargo(
|
||||||
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
|
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
|
||||||
cargo.rustflag("-Zon-broken-pipe=kill");
|
cargo.rustflag("-Zon-broken-pipe=kill");
|
||||||
|
|
||||||
if builder.config.llvm_enzyme {
|
// We temporarily disable linking here as part of some refactoring.
|
||||||
cargo.rustflag("-l").rustflag("Enzyme-19");
|
// This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now.
|
||||||
}
|
// In a follow-up PR, we will re-enable linking here and load the pass for them.
|
||||||
|
//if builder.config.llvm_enzyme {
|
||||||
|
// cargo.rustflag("-l").rustflag("Enzyme-19");
|
||||||
|
//}
|
||||||
|
|
||||||
// Building with protected visibility reduces the number of dynamic relocations needed, giving
|
// Building with protected visibility reduces the number of dynamic relocations needed, giving
|
||||||
// us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object
|
// us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object
|
||||||
|
|
23
src/doc/unstable-book/src/compiler-flags/autodiff.md
Normal file
23
src/doc/unstable-book/src/compiler-flags/autodiff.md
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# `autodiff`
|
||||||
|
|
||||||
|
The tracking issue for this feature is: [#124509](https://github.com/rust-lang/rust/issues/124509).
|
||||||
|
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
This feature allows you to differentiate functions using automatic differentiation.
|
||||||
|
Set the `-Zautodiff=<options>` compiler flag to adjust the behaviour of the autodiff feature.
|
||||||
|
Multiple options can be separated with a comma. Valid options are:
|
||||||
|
|
||||||
|
`PrintTA` - print Type Analysis Information
|
||||||
|
`PrintAA` - print Activity Analysis Information
|
||||||
|
`PrintPerf` - print Performance Warnings from Enzyme
|
||||||
|
`Print` - prints all intermediate transformations
|
||||||
|
`PrintModBefore` - print the whole module, before running opts
|
||||||
|
`PrintModAfterOpts` - print the whole module just before we pass it to Enzyme
|
||||||
|
`PrintModAfterEnzyme` - print the module after Enzyme differentiated everything
|
||||||
|
`LooseTypes` - Enzyme's loose type debug helper (can cause incorrect gradients)
|
||||||
|
`Inline` - runs Enzyme specific Inlining
|
||||||
|
`NoModOptAfter` - do not optimize the module after Enzyme is done
|
||||||
|
`EnableFncOpt` - tell Enzyme to run LLVM Opts on each function it generated
|
||||||
|
`NoVecUnroll` - do not unroll vectorized loops
|
||||||
|
`RuntimeActivity` - allow specifying activity at runtime
|
|
@ -1 +1 @@
|
||||||
Subproject commit 2fe5164a2423dd67ef25e2c4fb204fd06362494b
|
Subproject commit 0e5fa4a3d475f4dece489c9e06b11164f83789f5
|
Loading…
Add table
Add a link
Reference in a new issue