Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
This commit is contained in:
parent
52fd998399
commit
624c071b99
17 changed files with 1384 additions and 1 deletions
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
|
@ -0,0 +1,283 @@
|
|||
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
|
||||
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
|
||||
//! is the function to which the autodiff attribute is applied, and the target is the function
|
||||
//! getting generated by us (with a name given by the user as the first autodiff arg).
|
||||
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::expand::typetree::TypeTree;
|
||||
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||
use crate::ptr::P;
|
||||
use crate::{Ty, TyKind};
|
||||
|
||||
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
|
||||
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
|
||||
/// are a hack to support higher order derivatives. We need to compute first order derivatives
|
||||
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
|
||||
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
|
||||
/// as it's already done in the C++ and Julia frontend of Enzyme.
|
||||
///
|
||||
/// (FIXME) remove *First variants.
|
||||
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
|
||||
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
|
||||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub enum DiffMode {
|
||||
/// No autodiff is applied (used during error handling).
|
||||
Error,
|
||||
/// The primal function which we will differentiate.
|
||||
Source,
|
||||
/// The target function, to be created using forward mode AD.
|
||||
Forward,
|
||||
/// The target function, to be created using reverse mode AD.
|
||||
Reverse,
|
||||
/// The target function, to be created using forward mode AD.
|
||||
/// This target function will also be used as a source for higher order derivatives,
|
||||
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||
ForwardFirst,
|
||||
/// The target function, to be created using reverse mode AD.
|
||||
/// This target function will also be used as a source for higher order derivatives,
|
||||
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||
ReverseFirst,
|
||||
}
|
||||
|
||||
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
|
||||
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
|
||||
/// we add to the previous shadow value. To not surprise users, we picked different names.
|
||||
/// Dual numbers is also a quite well known name for forward mode AD types.
|
||||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub enum DiffActivity {
|
||||
/// Implicit or Explicit () return type, so a special case of Const.
|
||||
None,
|
||||
/// Don't compute derivatives with respect to this input/output.
|
||||
Const,
|
||||
/// Reverse Mode, Compute derivatives for this scalar input/output.
|
||||
Active,
|
||||
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
|
||||
/// the original return value.
|
||||
ActiveOnly,
|
||||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||
/// with it.
|
||||
Dual,
|
||||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||
/// with it. Drop the code which updates the original input/output for maximum performance.
|
||||
DualOnly,
|
||||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||
Duplicated,
|
||||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||
/// Drop the code which updates the original input for maximum performance.
|
||||
DuplicatedOnly,
|
||||
/// All Integers must be Const, but these are used to mark the integer which represents the
|
||||
/// length of a slice/vec. This is used for safety checks on slices.
|
||||
FakeActivitySize,
|
||||
}
|
||||
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub struct AutoDiffItem {
|
||||
/// The name of the function getting differentiated
|
||||
pub source: String,
|
||||
/// The name of the function being generated
|
||||
pub target: String,
|
||||
pub attrs: AutoDiffAttrs,
|
||||
/// Describe the memory layout of input types
|
||||
pub inputs: Vec<TypeTree>,
|
||||
/// Describe the memory layout of the output type
|
||||
pub output: TypeTree,
|
||||
}
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub struct AutoDiffAttrs {
|
||||
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
|
||||
/// e.g. in the [JAX
|
||||
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
||||
pub mode: DiffMode,
|
||||
pub ret_activity: DiffActivity,
|
||||
pub input_activity: Vec<DiffActivity>,
|
||||
}
|
||||
|
||||
impl DiffMode {
|
||||
pub fn is_rev(&self) -> bool {
|
||||
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
|
||||
}
|
||||
pub fn is_fwd(&self) -> bool {
|
||||
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for DiffMode {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
DiffMode::Error => write!(f, "Error"),
|
||||
DiffMode::Source => write!(f, "Source"),
|
||||
DiffMode::Forward => write!(f, "Forward"),
|
||||
DiffMode::Reverse => write!(f, "Reverse"),
|
||||
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
|
||||
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
|
||||
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
|
||||
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
|
||||
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
|
||||
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
|
||||
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||
if activity == DiffActivity::None {
|
||||
// Only valid if primal returns (), but we can't check that here.
|
||||
return true;
|
||||
}
|
||||
match mode {
|
||||
DiffMode::Error => false,
|
||||
DiffMode::Source => false,
|
||||
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||
activity == DiffActivity::Dual
|
||||
|| activity == DiffActivity::DualOnly
|
||||
|| activity == DiffActivity::Const
|
||||
}
|
||||
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||
activity == DiffActivity::Const
|
||||
|| activity == DiffActivity::Active
|
||||
|| activity == DiffActivity::ActiveOnly
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
|
||||
/// for the given argument, but we generally can't know the size of such a type.
|
||||
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
|
||||
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
|
||||
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
|
||||
/// users here from marking scalars as Duplicated, due to type aliases.
|
||||
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
|
||||
use DiffActivity::*;
|
||||
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
|
||||
if matches!(activity, Const) {
|
||||
return true;
|
||||
}
|
||||
if matches!(activity, Dual | DualOnly) {
|
||||
return true;
|
||||
}
|
||||
// FIXME(ZuseZ4) We should make this more robust to also
|
||||
// handle type aliases. Once that is done, we can be more restrictive here.
|
||||
if matches!(activity, Active | ActiveOnly) {
|
||||
return true;
|
||||
}
|
||||
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
|
||||
&& matches!(activity, Duplicated | DuplicatedOnly)
|
||||
}
|
||||
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||
use DiffActivity::*;
|
||||
return match mode {
|
||||
DiffMode::Error => false,
|
||||
DiffMode::Source => false,
|
||||
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||
matches!(activity, Dual | DualOnly | Const)
|
||||
}
|
||||
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Display for DiffActivity {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DiffActivity::None => write!(f, "None"),
|
||||
DiffActivity::Const => write!(f, "Const"),
|
||||
DiffActivity::Active => write!(f, "Active"),
|
||||
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
|
||||
DiffActivity::Dual => write!(f, "Dual"),
|
||||
DiffActivity::DualOnly => write!(f, "DualOnly"),
|
||||
DiffActivity::Duplicated => write!(f, "Duplicated"),
|
||||
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
|
||||
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for DiffMode {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<DiffMode, ()> {
|
||||
match s {
|
||||
"Error" => Ok(DiffMode::Error),
|
||||
"Source" => Ok(DiffMode::Source),
|
||||
"Forward" => Ok(DiffMode::Forward),
|
||||
"Reverse" => Ok(DiffMode::Reverse),
|
||||
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
|
||||
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromStr for DiffActivity {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<DiffActivity, ()> {
|
||||
match s {
|
||||
"None" => Ok(DiffActivity::None),
|
||||
"Active" => Ok(DiffActivity::Active),
|
||||
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
|
||||
"Const" => Ok(DiffActivity::Const),
|
||||
"Dual" => Ok(DiffActivity::Dual),
|
||||
"DualOnly" => Ok(DiffActivity::DualOnly),
|
||||
"Duplicated" => Ok(DiffActivity::Duplicated),
|
||||
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoDiffAttrs {
|
||||
pub fn has_ret_activity(&self) -> bool {
|
||||
self.ret_activity != DiffActivity::None
|
||||
}
|
||||
pub fn has_active_only_ret(&self) -> bool {
|
||||
self.ret_activity == DiffActivity::ActiveOnly
|
||||
}
|
||||
|
||||
pub fn error() -> Self {
|
||||
AutoDiffAttrs {
|
||||
mode: DiffMode::Error,
|
||||
ret_activity: DiffActivity::None,
|
||||
input_activity: Vec::new(),
|
||||
}
|
||||
}
|
||||
pub fn source() -> Self {
|
||||
AutoDiffAttrs {
|
||||
mode: DiffMode::Source,
|
||||
ret_activity: DiffActivity::None,
|
||||
input_activity: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_active(&self) -> bool {
|
||||
self.mode != DiffMode::Error
|
||||
}
|
||||
|
||||
pub fn is_source(&self) -> bool {
|
||||
self.mode == DiffMode::Source
|
||||
}
|
||||
pub fn apply_autodiff(&self) -> bool {
|
||||
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
||||
}
|
||||
|
||||
pub fn into_item(
|
||||
self,
|
||||
source: String,
|
||||
target: String,
|
||||
inputs: Vec<TypeTree>,
|
||||
output: TypeTree,
|
||||
) -> AutoDiffItem {
|
||||
AutoDiffItem { source, target, inputs, output, attrs: self }
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AutoDiffItem {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
||||
write!(f, " with attributes: {:?}", self.attrs)?;
|
||||
write!(f, " with inputs: {:?}", self.inputs)?;
|
||||
write!(f, " with output: {:?}", self.output)
|
||||
}
|
||||
}
|
|
@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
|
|||
use crate::MetaItem;
|
||||
|
||||
pub mod allocator;
|
||||
pub mod autodiff_attrs;
|
||||
pub mod typetree;
|
||||
|
||||
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
|
||||
pub struct StrippedCfgItem<ModId = DefId> {
|
||||
|
|
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
|
@ -0,0 +1,90 @@
|
|||
//! This module contains the definition of the `TypeTree` and `Type` structs.
|
||||
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
|
||||
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
|
||||
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
|
||||
//! represent what a type looks like "in memory". Enzyme can deduce this based on usage patterns in
|
||||
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
|
||||
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
|
||||
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
|
||||
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
|
||||
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means ints everywhere,
|
||||
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
|
||||
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
|
||||
//! Generally, it allows byte-specific descriptions.
|
||||
//! FIXME: This description might be partly inaccurate and should be extended, along with
|
||||
//! adding documentation to the corresponding Enzyme core code.
|
||||
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
|
||||
//! provide typetree information.
|
||||
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
|
||||
//! representations of some types might not be accurate. For example a vector of floats might be
|
||||
//! represented as a vector of u8s in MIR in some cases.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||
|
||||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub enum Kind {
|
||||
Anything,
|
||||
Integer,
|
||||
Pointer,
|
||||
Half,
|
||||
Float,
|
||||
Double,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub struct TypeTree(pub Vec<Type>);
|
||||
|
||||
impl TypeTree {
|
||||
pub fn new() -> Self {
|
||||
Self(Vec::new())
|
||||
}
|
||||
pub fn all_ints() -> Self {
|
||||
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
|
||||
}
|
||||
pub fn int(size: usize) -> Self {
|
||||
let mut ints = Vec::with_capacity(size);
|
||||
for i in 0..size {
|
||||
ints.push(Type {
|
||||
offset: i as isize,
|
||||
size: 1,
|
||||
kind: Kind::Integer,
|
||||
child: TypeTree::new(),
|
||||
});
|
||||
}
|
||||
Self(ints)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub struct FncTree {
|
||||
pub args: Vec<TypeTree>,
|
||||
pub ret: TypeTree,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
pub struct Type {
|
||||
pub offset: isize,
|
||||
pub size: usize,
|
||||
pub kind: Kind,
|
||||
pub child: TypeTree,
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn add_offset(self, add: isize) -> Self {
|
||||
let offset = match self.offset {
|
||||
-1 => add,
|
||||
x => add + x,
|
||||
};
|
||||
|
||||
Self { size: self.size, kind: self.kind, child: self.child, offset }
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Type {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
<Self as fmt::Debug>::fmt(self, f)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue