Make CodegenCx and Builder generic
Co-authored-by: Oli Scherer <github35764891676564198441@oli-obk.de>
This commit is contained in:
parent
a48e7b0057
commit
386c233858
10 changed files with 239 additions and 56 deletions
|
@ -3,20 +3,19 @@ use std::ptr;
|
|||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
|
||||
use rustc_codegen_ssa::ModuleCodegen;
|
||||
use rustc_codegen_ssa::back::write::ModuleConfig;
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_errors::FatalError;
|
||||
use rustc_middle::ty::TyCtxt;
|
||||
use rustc_session::config::Lto;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::back::write::{llvm_err, llvm_optimize};
|
||||
use crate::builder::Builder;
|
||||
use crate::declare::declare_raw_fn;
|
||||
use crate::builder::SBuilder;
|
||||
use crate::context::SimpleCx;
|
||||
use crate::declare::declare_simple_fn;
|
||||
use crate::errors::LlvmError;
|
||||
use crate::llvm::AttributePlace::Function;
|
||||
use crate::llvm::{Metadata, True};
|
||||
use crate::value::Value;
|
||||
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
|
||||
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
|
||||
|
||||
fn get_params(fnc: &Value) -> Vec<&Value> {
|
||||
unsafe {
|
||||
|
@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
|
|||
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
|
||||
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
|
||||
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
|
||||
fn generate_enzyme_call<'ll, 'tcx>(
|
||||
cx: &context::CodegenCx<'ll, 'tcx>,
|
||||
fn generate_enzyme_call<'ll>(
|
||||
cx: &SimpleCx<'ll>,
|
||||
fn_to_diff: &'ll Value,
|
||||
outer_fn: &'ll Value,
|
||||
attrs: AutoDiffAttrs,
|
||||
|
@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
|
|||
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
|
||||
// think a bit more about what should go here.
|
||||
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
|
||||
let ad_fn = declare_raw_fn(
|
||||
let ad_fn = declare_simple_fn(
|
||||
cx,
|
||||
&ad_name,
|
||||
llvm::CallConv::try_from(cc).expect("invalid callconv"),
|
||||
|
@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
|
|||
llvm::LLVMRustEraseInstFromParent(br);
|
||||
|
||||
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
|
||||
let mut builder = Builder::build(cx, entry);
|
||||
let mut builder = SBuilder::build(cx, entry);
|
||||
|
||||
let num_args = llvm::LLVMCountParams(&fn_to_diff);
|
||||
let mut args = Vec::with_capacity(num_args as usize + 1);
|
||||
|
@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
|
|||
}
|
||||
}
|
||||
|
||||
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
|
||||
let call = builder.call(enzyme_ty, ad_fn, &args, None);
|
||||
|
||||
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
|
||||
// metadata attachted to it, but we just created this code oota. Given that the
|
||||
|
@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn differentiate<'ll, 'tcx>(
|
||||
pub(crate) fn differentiate<'ll>(
|
||||
module: &'ll ModuleCodegen<ModuleLlvm>,
|
||||
cgcx: &CodegenContext<LlvmCodegenBackend>,
|
||||
tcx: TyCtxt<'tcx>,
|
||||
diff_items: Vec<AutoDiffItem>,
|
||||
config: &ModuleConfig,
|
||||
) -> Result<(), FatalError> {
|
||||
|
@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
|
|||
}
|
||||
|
||||
let diag_handler = cgcx.create_dcx();
|
||||
let (_, cgus) = tcx.collect_and_partition_mono_items(());
|
||||
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
|
||||
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
|
||||
|
||||
// Before dumping the module, we want all the TypeTrees to become part of the module.
|
||||
for item in diff_items.iter() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue