1
Fork 0

upstream rustc_codegen_llvm changes for enzyme/autodiff

This commit is contained in:
Manuel Drehwald 2025-01-01 21:42:45 +01:00
parent 372442fe5f
commit d753cbf779
17 changed files with 610 additions and 28 deletions

View file

@ -1,5 +1,6 @@
#include "LLVMWrapper.h"
#include "llvm-c/Analysis.h"
#include "llvm-c/Core.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@ -165,6 +166,30 @@ extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));
}
enum class LLVMRustVerifierFailureAction {
AbortProcessAction = 0,
PrintMessageAction = 1,
ReturnStatusAction = 2,
};
static LLVMVerifierFailureAction
fromRust(LLVMRustVerifierFailureAction Action) {
switch (Action) {
case LLVMRustVerifierFailureAction::AbortProcessAction:
return LLVMAbortProcessAction;
case LLVMRustVerifierFailureAction::PrintMessageAction:
return LLVMPrintMessageAction;
case LLVMRustVerifierFailureAction::ReturnStatusAction:
return LLVMReturnStatusAction;
}
report_fatal_error("Invalid LLVMVerifierFailureAction value!");
}
extern "C" LLVMBool
LLVMRustVerifyFunction(LLVMValueRef Fn, LLVMRustVerifierFailureAction Action) {
return LLVMVerifyFunction(Fn, fromRust(Action));
}
enum class LLVMRustTailCallKind {
None,
Tail,
@ -388,6 +413,17 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr,
AddAttributes(Call, Index, Attrs, AttrsLen);
}
extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) {
Instruction *ret = unwrap(BB)->getTerminator();
return wrap(ret);
}
extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) {
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
I->eraseFromParent();
}
}
extern "C" LLVMAttributeRef
LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) {
return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr)));
@ -954,6 +990,47 @@ extern "C" void LLVMRustAddModuleFlagString(
MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen)));
}
extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) {
auto Point = unwrap(BB)->rbegin();
if (Point != unwrap(BB)->rend())
return wrap(&*Point);
return nullptr;
}
extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
auto &BB = *unwrap(bb);
auto &Inst = *unwrap<Instruction>(I);
auto It = BB.begin();
while (&*It != &Inst)
++It;
// Make sure we found the Instruction.
assert(It != BB.end());
// We don't want to erase the instruction itself.
It--;
// Delete in rev order to ensure no dangling references.
while (It != BB.begin()) {
auto Prev = std::prev(It);
It->eraseFromParent();
It = Prev;
}
It->eraseFromParent();
}
extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) {
if (auto *I = dyn_cast<Instruction>(unwrap<Value>(inst))) {
return I->hasMetadata(kindID);
}
return false;
}
extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) {
if (auto *I = dyn_cast<Instruction>(unwrap<Value>(x))) {
auto *MD = I->getDebugLoc().getAsMDNode();
return wrap(MD);
}
return nullptr;
}
extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind,
LLVMMetadataRef MD) {
unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD));