improve cold_path()

This commit is contained in:
Jiri Bobek 2024-12-04 12:30:42 +01:00
parent c705b7d6f7
commit 7bb5f4dd78
6 changed files with 230 additions and 15 deletions

View file

@ -4,7 +4,7 @@ use std::{iter, ptr};
pub(crate) mod autodiff;
use libc::{c_char, c_uint};
use libc::{c_char, c_uint, size_t};
use rustc_abi as abi;
use rustc_abi::{Align, Size, WrappingRange};
use rustc_codegen_ssa::MemFlags;
@ -32,7 +32,7 @@ use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::common::Funclet;
use crate::context::{CodegenCx, SimpleCx};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
use crate::value::Value;
@ -333,6 +333,50 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}
fn switch_with_weights(
&mut self,
v: Self::Value,
else_llbb: Self::BasicBlock,
else_is_cold: bool,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
) {
if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
self.switch(v, else_llbb, cases.map(|(val, dest, _)| (val, dest)));
return;
}
let id_str = "branch_weights";
let id = unsafe {
llvm::LLVMMDStringInContext2(self.cx.llcx, id_str.as_ptr().cast(), id_str.len())
};
// For switch instructions with 2 targets, the `llvm.expect` intrinsic is used.
// This function handles switch instructions with more than 2 targets and it needs to
// emit branch weights metadata instead of using the intrinsic.
// The values 1 and 2000 are the same as the values used by the `llvm.expect` intrinsic.
let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) };
let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) };
let weight =
|is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } };
let mut md: SmallVec<[&Metadata; 16]> = SmallVec::with_capacity(cases.len() + 2);
md.push(id);
md.push(weight(else_is_cold));
let switch =
unsafe { llvm::LLVMBuildSwitch(self.llbuilder, v, else_llbb, cases.len() as c_uint) };
for (on_val, dest, is_cold) in cases {
let on_val = self.const_uint_big(self.val_ty(v), on_val);
unsafe { llvm::LLVMAddCase(switch, on_val, dest) }
md.push(weight(is_cold));
}
unsafe {
let md_node = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len() as size_t);
self.cx.set_metadata(switch, llvm::MD_prof, md_node);
}
}
fn invoke(
&mut self,
llty: &'ll Type,

View file

@ -429,11 +429,34 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
bx.cond_br(cmp, ll1, ll2);
} else {
let otherwise = targets.otherwise();
let otherwise_cold = self.cold_blocks[otherwise];
let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
let none_cold = cold_count == 0;
let all_cold = cold_count == targets.iter().len();
if (none_cold && (!otherwise_cold || otherwise_unreachable))
|| (all_cold && (otherwise_cold || otherwise_unreachable))
{
// All targets have the same weight,
// or `otherwise` is unreachable and it's the only target with a different weight.
bx.switch(
discr_value,
helper.llbb_with_cleanup(self, targets.otherwise()),
target_iter.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
target_iter
.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
);
} else {
// Targets have different weights
bx.switch_with_weights(
discr_value,
helper.llbb_with_cleanup(self, targets.otherwise()),
otherwise_cold,
target_iter.map(|(value, target)| {
(value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
}),
);
}
}
}

View file

@ -502,16 +502,27 @@ fn find_cold_blocks<'tcx>(
for (bb, bb_data) in traversal::postorder(mir) {
let terminator = bb_data.terminator();
match terminator.kind {
// If a BB ends with a call to a cold function, mark it as cold.
if let mir::TerminatorKind::Call { ref func, .. } = terminator.kind
&& let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
mir::TerminatorKind::Call { ref func, .. }
| mir::TerminatorKind::TailCall { ref func, .. }
if let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
&& let attrs = tcx.codegen_fn_attrs(def_id)
&& attrs.flags.contains(CodegenFnAttrFlags::COLD)
&& attrs.flags.contains(CodegenFnAttrFlags::COLD) =>
{
cold_blocks[bb] = true;
continue;
}
// If a BB ends with an `unreachable`, also mark it as cold.
mir::TerminatorKind::Unreachable => {
cold_blocks[bb] = true;
continue;
}
_ => {}
}
// If all successors of a BB are cold and there's at least one of them, mark this BB as cold
let mut succ = terminator.successors();
if let Some(first) = succ.next()

View file

@ -110,6 +110,20 @@ pub trait BuilderMethods<'a, 'tcx>:
else_llbb: Self::BasicBlock,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
);
// This is like `switch()`, but every case has a bool flag indicating whether it's cold.
//
// Default implementation throws away the cold flags and calls `switch()`.
fn switch_with_weights(
&mut self,
v: Self::Value,
else_llbb: Self::BasicBlock,
_else_is_cold: bool,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
) {
self.switch(v, else_llbb, cases.map(|(val, bb, _)| (val, bb)))
}
fn invoke(
&mut self,
llty: Self::Type,

View file

@ -0,0 +1,36 @@
//@ compile-flags: -O
#![crate_type = "lib"]
#![feature(core_intrinsics)]
use std::intrinsics::cold_path;
#[inline(never)]
#[no_mangle]
pub fn path_a() {
println!("path a");
}
#[inline(never)]
#[no_mangle]
pub fn path_b() {
println!("path b");
}
#[no_mangle]
pub fn test(x: Option<bool>) {
if let Some(_) = x {
path_a();
} else {
cold_path();
path_b();
}
// CHECK-LABEL: @test(
// CHECK: br i1 %1, label %bb2, label %bb1, !prof ![[NUM:[0-9]+]]
// CHECK: bb1:
// CHECK: path_a
// CHECK: bb2:
// CHECK: path_b
}
// CHECK: ![[NUM]] = !{!"branch_weights", {{(!"expected", )?}}i32 1, i32 2000}

View file

@ -0,0 +1,87 @@
//@ compile-flags: -O
#![crate_type = "lib"]
#![feature(core_intrinsics)]
use std::intrinsics::cold_path;
#[inline(never)]
#[no_mangle]
pub fn path_a() {
println!("path a");
}
#[inline(never)]
#[no_mangle]
pub fn path_b() {
println!("path b");
}
#[inline(never)]
#[no_mangle]
pub fn path_c() {
println!("path c");
}
#[inline(never)]
#[no_mangle]
pub fn path_d() {
println!("path d");
}
#[no_mangle]
pub fn test(x: Option<u32>) {
match x {
Some(0) => path_a(),
Some(1) => {
cold_path();
path_b()
}
Some(2) => path_c(),
Some(3) => {
cold_path();
path_d()
}
_ => path_a(),
}
// CHECK-LABEL: @test(
// CHECK: switch i32 %1, label %bb1 [
// CHECK: i32 0, label %bb6
// CHECK: i32 1, label %bb5
// CHECK: i32 2, label %bb4
// CHECK: i32 3, label %bb3
// CHECK: ], !prof ![[NUM1:[0-9]+]]
}
#[no_mangle]
pub fn test2(x: Option<u32>) {
match x {
Some(10) => path_a(),
Some(11) => {
cold_path();
path_b()
}
Some(12) => {
unsafe { core::intrinsics::unreachable() };
path_c()
}
Some(13) => {
cold_path();
path_d()
}
_ => {
cold_path();
path_a()
}
}
// CHECK-LABEL: @test2(
// CHECK: switch i32 %1, label %bb1 [
// CHECK: i32 10, label %bb5
// CHECK: i32 11, label %bb4
// CHECK: i32 13, label %bb3
// CHECK: ], !prof ![[NUM2:[0-9]+]]
}
// CHECK: ![[NUM1]] = !{!"branch_weights", i32 2000, i32 2000, i32 1, i32 2000, i32 1}
// CHECK: ![[NUM2]] = !{!"branch_weights", i32 1, i32 2000, i32 1, i32 1}