improve cold_path()

This commit is contained in:
Jiri Bobek 2024-12-04 12:30:42 +01:00
parent c705b7d6f7
commit 7bb5f4dd78
6 changed files with 230 additions and 15 deletions

View file

@ -4,7 +4,7 @@ use std::{iter, ptr};
pub(crate) mod autodiff;
use libc::{c_char, c_uint};
use libc::{c_char, c_uint, size_t};
use rustc_abi as abi;
use rustc_abi::{Align, Size, WrappingRange};
use rustc_codegen_ssa::MemFlags;
@ -32,7 +32,7 @@ use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::common::Funclet;
use crate::context::{CodegenCx, SimpleCx};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
use crate::value::Value;
@ -333,6 +333,50 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}
fn switch_with_weights(
&mut self,
v: Self::Value,
else_llbb: Self::BasicBlock,
else_is_cold: bool,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
) {
if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
self.switch(v, else_llbb, cases.map(|(val, dest, _)| (val, dest)));
return;
}
let id_str = "branch_weights";
let id = unsafe {
llvm::LLVMMDStringInContext2(self.cx.llcx, id_str.as_ptr().cast(), id_str.len())
};
// For switch instructions with 2 targets, the `llvm.expect` intrinsic is used.
// This function handles switch instructions with more than 2 targets and it needs to
// emit branch weights metadata instead of using the intrinsic.
// The values 1 and 2000 are the same as the values used by the `llvm.expect` intrinsic.
let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) };
let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) };
let weight =
|is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } };
let mut md: SmallVec<[&Metadata; 16]> = SmallVec::with_capacity(cases.len() + 2);
md.push(id);
md.push(weight(else_is_cold));
let switch =
unsafe { llvm::LLVMBuildSwitch(self.llbuilder, v, else_llbb, cases.len() as c_uint) };
for (on_val, dest, is_cold) in cases {
let on_val = self.const_uint_big(self.val_ty(v), on_val);
unsafe { llvm::LLVMAddCase(switch, on_val, dest) }
md.push(weight(is_cold));
}
unsafe {
let md_node = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len() as size_t);
self.cx.set_metadata(switch, llvm::MD_prof, md_node);
}
}
fn invoke(
&mut self,
llty: &'ll Type,