upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
This commit is contained in:
parent
ebcf860e73
commit
1f30517d40
27 changed files with 482 additions and 38 deletions
|
@ -257,7 +257,7 @@ struct SharedState<'tcx> {
|
|||
|
||||
pub(crate) struct UsageMap<'tcx> {
|
||||
// Maps every mono item to the mono items used by it.
|
||||
used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
|
||||
// Maps every mono item to the mono items that use it.
|
||||
user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
|
|
|
@ -92,6 +92,8 @@
|
|||
//! source-level module, functions from the same module will be available for
|
||||
//! inlining, even when they are not marked `#[inline]`.
|
||||
|
||||
mod autodiff;
|
||||
|
||||
use std::cmp;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::fs::{self, File};
|
||||
|
@ -251,7 +253,17 @@ where
|
|||
can_export_generics,
|
||||
always_export_generics,
|
||||
);
|
||||
if visibility == Visibility::Hidden && can_be_internalized {
|
||||
|
||||
// We can't differentiate something that got inlined.
|
||||
let autodiff_active = cfg!(llvm_enzyme)
|
||||
&& cx
|
||||
.tcx
|
||||
.codegen_fn_attrs(mono_item.def_id())
|
||||
.autodiff_item
|
||||
.as_ref()
|
||||
.is_some_and(|ad| ad.is_active());
|
||||
|
||||
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
|
||||
internalization_candidates.insert(mono_item);
|
||||
}
|
||||
let size_estimate = mono_item.size_estimate(cx.tcx);
|
||||
|
@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
|||
})
|
||||
.collect();
|
||||
|
||||
let autodiff_mono_items: Vec<_> = items
|
||||
.iter()
|
||||
.filter_map(|item| match *item {
|
||||
MonoItem::Fn(ref instance) => Some((item, instance)),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let autodiff_items =
|
||||
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
|
||||
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
|
||||
|
||||
// Output monomorphization stats per def_id
|
||||
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
|
||||
if let Err(err) =
|
||||
|
@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
|||
}
|
||||
}
|
||||
|
||||
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
|
||||
MonoItemPartitions {
|
||||
all_mono_items: tcx.arena.alloc(mono_items),
|
||||
codegen_units,
|
||||
autodiff_items,
|
||||
}
|
||||
}
|
||||
|
||||
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
|
||||
|
|
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
|
||||
use rustc_hir::def_id::LOCAL_CRATE;
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::mir::mono::MonoItem;
|
||||
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
||||
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::partitioning::UsageMap;
|
||||
|
||||
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
|
||||
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||
}
|
||||
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
|
||||
|
||||
// If rustc compiles the unmodified primal, we know that this copy of the function
|
||||
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
|
||||
// (or actually at all), so let's strip lifetimes when computing the layout.
|
||||
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
|
||||
let mut new_activities = vec![];
|
||||
let mut new_positions = vec![];
|
||||
for (i, ty) in x.inputs().iter().enumerate() {
|
||||
if let Some(inner_ty) = ty.builtin_deref(true) {
|
||||
if ty.is_fn_ptr() {
|
||||
// FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
|
||||
// since Enzyme itself can handle them.
|
||||
tcx.dcx().err("function pointers are currently not supported in autodiff");
|
||||
}
|
||||
if inner_ty.is_slice() {
|
||||
// We know that the length will be passed as extra arg.
|
||||
if !da.is_empty() {
|
||||
// We are looking at a slice. The length of that slice will become an
|
||||
// extra integer on llvm level. Integers are always const.
|
||||
// However, if the slice get's duplicated, we want to know to later check the
|
||||
// size. So we mark the new size argument as FakeActivitySize.
|
||||
let activity = match da[i] {
|
||||
DiffActivity::DualOnly
|
||||
| DiffActivity::Dual
|
||||
| DiffActivity::DuplicatedOnly
|
||||
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
|
||||
DiffActivity::Const => DiffActivity::Const,
|
||||
_ => bug!("unexpected activity for ptr/ref"),
|
||||
};
|
||||
new_activities.push(activity);
|
||||
new_positions.push(i + 1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// now add the extra activities coming from slices
|
||||
// Reverse order to not invalidate the indices
|
||||
for _ in 0..new_activities.len() {
|
||||
let pos = new_positions.pop().unwrap();
|
||||
let activity = new_activities.pop().unwrap();
|
||||
da.insert(pos, activity);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn find_autodiff_source_functions<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
usage_map: &UsageMap<'tcx>,
|
||||
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
|
||||
) -> Vec<AutoDiffItem> {
|
||||
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
|
||||
for (item, instance) in autodiff_mono_items {
|
||||
let target_id = instance.def_id();
|
||||
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
|
||||
let Some(target_attrs) = cg_fn_attr else {
|
||||
continue;
|
||||
};
|
||||
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
|
||||
if target_attrs.is_source() {
|
||||
trace!("source found: {:?}", target_id);
|
||||
}
|
||||
if !target_attrs.apply_autodiff() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
|
||||
|
||||
let source =
|
||||
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
|
||||
MonoItem::Fn(ref instance_s) => {
|
||||
let source_id = instance_s.def_id();
|
||||
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
|
||||
&& ad.is_active()
|
||||
{
|
||||
return Some(instance_s);
|
||||
}
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
let inst = match source {
|
||||
Some(source) => source,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
debug!("source_id: {:?}", inst.def_id());
|
||||
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
|
||||
assert!(fn_ty.is_fn());
|
||||
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
|
||||
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
|
||||
|
||||
let mut new_target_attrs = target_attrs.clone();
|
||||
new_target_attrs.input_activity = input_activities;
|
||||
let itm = new_target_attrs.into_item(symb, target_symbol);
|
||||
autodiff_items.push(itm);
|
||||
}
|
||||
|
||||
if !autodiff_items.is_empty() {
|
||||
trace!("AUTODIFF ITEMS EXIST");
|
||||
for item in &mut *autodiff_items {
|
||||
trace!("{}", &item);
|
||||
}
|
||||
}
|
||||
|
||||
autodiff_items
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue