Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk
Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle This PR should not be merged until the rustc_codegen_llvm part is merged. I will also alter it a little based on what gets shaved off from the cg_llvm PR, and address some of the feedback I received in the other PR (including cleanups). I am already putting it up to 1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and 2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost. Re 1: My tests require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either of those efforts pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thought and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please? For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: https://github.com/EnzymeAD/rust/issues/173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just gets dropped, but I don't know WHY. Any help would be super appreciated, as rustc queries look a bit voodoo to me. Tracking: - https://github.com/rust-lang/rust/issues/124509 r? `@jieyouxu`
This commit is contained in:
commit
c19c4b91f5
27 changed files with 482 additions and 38 deletions
|
@ -257,7 +257,7 @@ struct SharedState<'tcx> {
|
|||
|
||||
pub(crate) struct UsageMap<'tcx> {
|
||||
// Maps every mono item to the mono items used by it.
|
||||
used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
|
||||
// Maps every mono item to the mono items that use it.
|
||||
user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
|
||||
|
|
|
@ -92,6 +92,8 @@
|
|||
//! source-level module, functions from the same module will be available for
|
||||
//! inlining, even when they are not marked `#[inline]`.
|
||||
|
||||
mod autodiff;
|
||||
|
||||
use std::cmp;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::fs::{self, File};
|
||||
|
@ -251,7 +253,17 @@ where
|
|||
can_export_generics,
|
||||
always_export_generics,
|
||||
);
|
||||
if visibility == Visibility::Hidden && can_be_internalized {
|
||||
|
||||
// We can't differentiate something that got inlined.
|
||||
let autodiff_active = cfg!(llvm_enzyme)
|
||||
&& cx
|
||||
.tcx
|
||||
.codegen_fn_attrs(mono_item.def_id())
|
||||
.autodiff_item
|
||||
.as_ref()
|
||||
.is_some_and(|ad| ad.is_active());
|
||||
|
||||
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
|
||||
internalization_candidates.insert(mono_item);
|
||||
}
|
||||
let size_estimate = mono_item.size_estimate(cx.tcx);
|
||||
|
@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
|||
})
|
||||
.collect();
|
||||
|
||||
let autodiff_mono_items: Vec<_> = items
|
||||
.iter()
|
||||
.filter_map(|item| match *item {
|
||||
MonoItem::Fn(ref instance) => Some((item, instance)),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let autodiff_items =
|
||||
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
|
||||
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
|
||||
|
||||
// Output monomorphization stats per def_id
|
||||
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
|
||||
if let Err(err) =
|
||||
|
@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
|
|||
}
|
||||
}
|
||||
|
||||
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
|
||||
MonoItemPartitions {
|
||||
all_mono_items: tcx.arena.alloc(mono_items),
|
||||
codegen_units,
|
||||
autodiff_items,
|
||||
}
|
||||
}
|
||||
|
||||
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
|
||||
|
|
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
121
compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
|
||||
use rustc_hir::def_id::LOCAL_CRATE;
|
||||
use rustc_middle::bug;
|
||||
use rustc_middle::mir::mono::MonoItem;
|
||||
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
|
||||
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::partitioning::UsageMap;
|
||||
|
||||
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
|
||||
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
|
||||
bug!("expected fn def for autodiff, got {:?}", fn_ty);
|
||||
}
|
||||
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
|
||||
|
||||
// If rustc compiles the unmodified primal, we know that this copy of the function
|
||||
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
|
||||
// (or actually at all), so let's strip lifetimes when computing the layout.
|
||||
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
|
||||
let mut new_activities = vec![];
|
||||
let mut new_positions = vec![];
|
||||
for (i, ty) in x.inputs().iter().enumerate() {
|
||||
if let Some(inner_ty) = ty.builtin_deref(true) {
|
||||
if ty.is_fn_ptr() {
|
||||
// FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
|
||||
// since Enzyme itself can handle them.
|
||||
tcx.dcx().err("function pointers are currently not supported in autodiff");
|
||||
}
|
||||
if inner_ty.is_slice() {
|
||||
// We know that the length will be passed as extra arg.
|
||||
if !da.is_empty() {
|
||||
// We are looking at a slice. The length of that slice will become an
|
||||
// extra integer on llvm level. Integers are always const.
|
||||
// However, if the slice get's duplicated, we want to know to later check the
|
||||
// size. So we mark the new size argument as FakeActivitySize.
|
||||
let activity = match da[i] {
|
||||
DiffActivity::DualOnly
|
||||
| DiffActivity::Dual
|
||||
| DiffActivity::DuplicatedOnly
|
||||
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
|
||||
DiffActivity::Const => DiffActivity::Const,
|
||||
_ => bug!("unexpected activity for ptr/ref"),
|
||||
};
|
||||
new_activities.push(activity);
|
||||
new_positions.push(i + 1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// now add the extra activities coming from slices
|
||||
// Reverse order to not invalidate the indices
|
||||
for _ in 0..new_activities.len() {
|
||||
let pos = new_positions.pop().unwrap();
|
||||
let activity = new_activities.pop().unwrap();
|
||||
da.insert(pos, activity);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn find_autodiff_source_functions<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
usage_map: &UsageMap<'tcx>,
|
||||
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
|
||||
) -> Vec<AutoDiffItem> {
|
||||
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
|
||||
for (item, instance) in autodiff_mono_items {
|
||||
let target_id = instance.def_id();
|
||||
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
|
||||
let Some(target_attrs) = cg_fn_attr else {
|
||||
continue;
|
||||
};
|
||||
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
|
||||
if target_attrs.is_source() {
|
||||
trace!("source found: {:?}", target_id);
|
||||
}
|
||||
if !target_attrs.apply_autodiff() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
|
||||
|
||||
let source =
|
||||
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
|
||||
MonoItem::Fn(ref instance_s) => {
|
||||
let source_id = instance_s.def_id();
|
||||
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
|
||||
&& ad.is_active()
|
||||
{
|
||||
return Some(instance_s);
|
||||
}
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
let inst = match source {
|
||||
Some(source) => source,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
debug!("source_id: {:?}", inst.def_id());
|
||||
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
|
||||
assert!(fn_ty.is_fn());
|
||||
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
|
||||
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
|
||||
|
||||
let mut new_target_attrs = target_attrs.clone();
|
||||
new_target_attrs.input_activity = input_activities;
|
||||
let itm = new_target_attrs.into_item(symb, target_symbol);
|
||||
autodiff_items.push(itm);
|
||||
}
|
||||
|
||||
if !autodiff_items.is_empty() {
|
||||
trace!("AUTODIFF ITEMS EXIST");
|
||||
for item in &mut *autodiff_items {
|
||||
trace!("{}", &item);
|
||||
}
|
||||
}
|
||||
|
||||
autodiff_items
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue