Don't alloca just to look at a discriminant

Today, reading the discriminant of any enum you match on forces the value into memory (an `alloca`) first, which makes LLVM do a bunch of extra work even for trivial stuff like `Option<bool>`. Let's not.
This commit is contained in:
Scott McMurray 2025-03-12 00:38:14 -07:00
parent 6650252439
commit 143f39362a
9 changed files with 177 additions and 167 deletions

View file

@ -205,7 +205,12 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
| PlaceContext::MutatingUse(MutatingUseContext::Retag) => {}
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
NonMutatingUseContext::Copy
| NonMutatingUseContext::Move
// Inspect covers things like `PtrMetadata` and `Discriminant`
// which we can treat similar to `Copy` use for the purpose of
// whether we can use SSA variables for things.
| NonMutatingUseContext::Inspect,
) => match &mut self.locals[local] {
LocalKind::ZST => {}
LocalKind::Memory => {}
@ -229,8 +234,7 @@ impl<'a, 'b, 'tcx, Bx: BuilderMethods<'b, 'tcx>> Visitor<'tcx> for LocalAnalyzer
| MutatingUseContext::Projection,
)
| PlaceContext::NonMutatingUse(
NonMutatingUseContext::Inspect
| NonMutatingUseContext::SharedBorrow
NonMutatingUseContext::SharedBorrow
| NonMutatingUseContext::FakeBorrow
| NonMutatingUseContext::RawBorrow
| NonMutatingUseContext::Projection,

View file

@ -62,7 +62,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let callee_ty = instance.ty(bx.tcx(), bx.typing_env());
let ty::FnDef(def_id, fn_args) = *callee_ty.kind() else {
bug!("expected fn item type, found {}", callee_ty);
span_bug!(span, "expected fn item type, found {}", callee_ty);
};
let sig = callee_ty.fn_sig(bx.tcx());
@ -325,14 +325,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}
sym::discriminant_value => {
if ret_ty.is_integral() {
args[0].deref(bx.cx()).codegen_get_discr(bx, ret_ty)
} else {
span_bug!(span, "Invalid discriminant type for `{:?}`", arg_tys[0])
}
}
// This requires that atomic intrinsics follow a specific naming pattern:
// "atomic_<operation>[_<ordering>]"
name if let Some(atomic) = name_str.strip_prefix("atomic_") => {

View file

@ -3,16 +3,17 @@ use std::fmt;
use arrayvec::ArrayVec;
use either::Either;
use rustc_abi as abi;
use rustc_abi::{Align, BackendRepr, Size};
use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
use rustc_middle::mir::{self, ConstValue};
use rustc_middle::ty::Ty;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::{bug, span_bug};
use tracing::debug;
use tracing::{debug, instrument};
use super::place::{PlaceRef, PlaceValue};
use super::{FunctionCx, LocalRef};
use crate::common::IntPredicate;
use crate::traits::*;
use crate::{MemFlags, size_of_val};
@ -415,6 +416,140 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
OperandRef { val, layout: field }
}
/// Obtain the actual discriminant of a value.
///
/// Emits the code needed to compute the discriminant of the operand `self`
/// and returns it as an immediate of the backend type that corresponds to
/// `cast_to`.
///
/// * `fx` - function context; only used to extract the tag field when the
///   operand is held as an `Immediate` or a `Pair`.
/// * `bx` - builder used to emit instructions and constants.
/// * `cast_to` - type the resulting discriminant is converted to.
///
/// No tag is read for uninhabited layouts (a poison constant is returned) or
/// for single-variant layouts (a constant is returned).
///
/// # Panics
///
/// Hits `bug!` if a multi-variant layout is paired with a
/// `OperandValue::ZeroSized` operand, and fails an assertion if a
/// single-variant layout has no discriminant and its variant index is not
/// `FIRST_VARIANT`.
#[instrument(level = "trace", skip(fx, bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
fx: &mut FunctionCx<'a, 'tcx, Bx>,
bx: &mut Bx,
cast_to: Ty<'tcx>,
) -> V {
// Only needed below to pick a pointer-sized integer for pointer tags.
let dl = &bx.tcx().data_layout;
// From here on `cast_to` is the backend (immediate) type, not the `Ty`.
let cast_to_layout = bx.cx().layout_of(cast_to);
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
// An uninhabited value has no meaningful discriminant to read.
if self.layout.is_uninhabited() {
return bx.cx().const_poison(cast_to);
}
let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
Variants::Empty => unreachable!("we already handled uninhabited types"),
Variants::Single { index } => {
// Only one variant is possible, so the discriminant is a constant.
let discr_val =
if let Some(discr) = self.layout.ty.discriminant_for_variant(bx.tcx(), index) {
discr.val
} else {
// The type reports no discriminant for this variant; the only
// variant index accepted here is the first one, and `0` is
// used as its discriminant.
// NOTE(review): presumably this is the non-enum case -- confirm.
assert_eq!(index, FIRST_VARIANT);
0
};
return bx.cx().const_uint_big(cast_to, discr_val);
}
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
(tag, tag_encoding, tag_field)
}
};
// Read the tag/niche-encoded discriminant, either directly out of the
// operand's immediate value(s) or, for by-reference operands, from memory.
let tag_op = match self.val {
OperandValue::ZeroSized => bug!(),
OperandValue::Immediate(_) | OperandValue::Pair(_, _) => {
self.extract_field(fx, bx, tag_field)
}
OperandValue::Ref(place) => {
let tag = place.with_type(self.layout).project_field(bx, tag_field);
bx.load_operand(tag)
}
};
let tag_imm = tag_op.immediate();
// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
Primitive::Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag_imm, cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Cast to an integer so we don't have to treat a pointer as a
// special case.
let (tag, tag_llty) = match tag_scalar.primitive() {
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
Primitive::Pointer(_) => {
let t = bx.type_from_integer(dl.ptr_sized_integer());
let tag = bx.ptrtoint(tag_imm, t);
(tag, t)
}
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
};
// Number of niche-encoded variants, minus one.
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
// The tag values `niche_start..=niche_start + relative_max` are the
// "niche values".
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.
//
// `delta` is what still has to be added to `tagged_discr` to turn it
// into the variant index.
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_variants.start()
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
// The constant above already is the variant index, so nothing to add.
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
// The relative tag is converted as an unsigned value.
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
// Skip emitting an `add` when there is nothing to add.
let tagged_discr = if delta == 0 {
tagged_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};
// Any tag outside the niche range selects the untagged variant's index.
let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);
// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this seems to be a pessimization.
discr
}
}
}
}
impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {

View file

@ -1,4 +1,3 @@
use rustc_abi::Primitive::{Int, Pointer};
use rustc_abi::{Align, BackendRepr, FieldsShape, Size, TagEncoding, VariantIdx, Variants};
use rustc_middle::mir::PlaceTy;
use rustc_middle::mir::interpret::Scalar;
@ -233,129 +232,6 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
val.with_type(field)
}
/// Obtain the actual discriminant of a value.
///
/// Emits the code needed to compute the discriminant of the value stored at
/// the place `self` and returns it as an immediate of the backend type that
/// corresponds to `cast_to`.
///
/// * `bx` - builder used to emit instructions and constants.
/// * `cast_to` - type the resulting discriminant is converted to.
///
/// For multi-variant layouts the tag is always loaded from memory through a
/// field projection of this place. Uninhabited layouts yield a poison
/// constant and single-variant layouts yield a constant without any load.
#[instrument(level = "trace", skip(bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
bx: &mut Bx,
cast_to: Ty<'tcx>,
) -> V {
// Only needed below to pick a pointer-sized integer for pointer tags.
let dl = &bx.tcx().data_layout;
// From here on `cast_to` is the backend (immediate) type, not the `Ty`.
let cast_to_layout = bx.cx().layout_of(cast_to);
let cast_to = bx.cx().immediate_backend_type(cast_to_layout);
// An uninhabited value has no meaningful discriminant to read.
if self.layout.is_uninhabited() {
return bx.cx().const_poison(cast_to);
}
let (tag_scalar, tag_encoding, tag_field) = match self.layout.variants {
Variants::Empty => unreachable!("we already handled uninhabited types"),
Variants::Single { index } => {
// Only one variant is possible, so the discriminant is a constant:
// the type's discriminant for that variant if it reports one,
// otherwise the variant index itself.
let discr_val = self
.layout
.ty
.discriminant_for_variant(bx.cx().tcx(), index)
.map_or(index.as_u32() as u128, |discr| discr.val);
return bx.cx().const_uint_big(cast_to, discr_val);
}
Variants::Multiple { tag, ref tag_encoding, tag_field, .. } => {
(tag, tag_encoding, tag_field)
}
};
// Read the tag/niche-encoded discriminant from memory.
let tag = self.project_field(bx, tag_field);
let tag_op = bx.load_operand(tag);
let tag_imm = tag_op.immediate();
// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
Int(_, signed) => !tag_scalar.is_bool() && signed,
_ => false,
};
bx.intcast(tag_imm, cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Cast to an integer so we don't have to treat a pointer as a
// special case.
let (tag, tag_llty) = match tag_scalar.primitive() {
// FIXME(erikdesjardins): handle non-default addrspace ptr sizes
Pointer(_) => {
let t = bx.type_from_integer(dl.ptr_sized_integer());
let tag = bx.ptrtoint(tag_imm, t);
(tag, t)
}
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
};
// Number of niche-encoded variants, minus one.
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
// The tag values `niche_start..=niche_start + relative_max` are the
// "niche values".
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.
//
// `delta` is what still has to be added to `tagged_discr` to turn it
// into the variant index.
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_variants.start()
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
// The constant above already is the variant index, so nothing to add.
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
// The relative tag is converted as an unsigned value.
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
// Skip emitting an `add` when there is nothing to add.
let tagged_discr = if delta == 0 {
tagged_discr
} else {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};
// Any tag outside the niche range selects the untagged variant's index.
let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);
// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this seems to be a pessimization.
discr
}
}
}
/// Sets the discriminant for a new value of the given case of the given
/// representation.
pub fn codegen_set_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(

View file

@ -706,7 +706,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
mir::Rvalue::Discriminant(ref place) => {
let discr_ty = rvalue.ty(self.mir, bx.tcx());
let discr_ty = self.monomorphize(discr_ty);
let discr = self.codegen_place(bx, place.as_ref()).codegen_get_discr(bx, discr_ty);
let operand = self.codegen_consume(bx, place.as_ref());
let discr = operand.codegen_get_discr(self, bx, discr_ty);
OperandRef {
val: OperandValue::Immediate(discr),
layout: self.cx.layout_of(discr_ty),

View file

@ -7,19 +7,22 @@
// CHECK-LABEL: @option_match
#[no_mangle]
pub fn option_match(x: Option<i32>) -> u16 {
// CHECK: %x = alloca [8 x i8]
// CHECK: store i32 %0, ptr %x
// CHECK: %[[TAG:.+]] = load i32, ptr %x
// CHECK-SAME: !range ![[ZERO_ONE_32:[0-9]+]]
// CHECK: %[[DISCR:.+]] = zext i32 %[[TAG]] to i64
// CHECK-NOT: %x = alloca
// CHECK: %[[OUT:.+]] = alloca [2 x i8]
// CHECK-NOT: %x = alloca
// CHECK: %[[DISCR:.+]] = zext i32 %x.0 to i64
// CHECK: %[[COND:.+]] = trunc nuw i64 %[[DISCR]] to i1
// CHECK: br i1 %[[COND]], label %[[TRUE:[a-z0-9]+]], label %[[FALSE:[a-z0-9]+]]
// CHECK: [[TRUE]]:
// CHECK: store i16 13
// CHECK: store i16 13, ptr %[[OUT]]
// CHECK: [[FALSE]]:
// CHECK: store i16 42
// CHECK: store i16 42, ptr %[[OUT]]
// CHECK: %[[RET:.+]] = load i16, ptr %[[OUT]]
// CHECK: ret i16 %[[RET]]
match x {
Some(_) => 13,
None => 42,
@ -29,23 +32,23 @@ pub fn option_match(x: Option<i32>) -> u16 {
// CHECK-LABEL: @result_match
#[no_mangle]
pub fn result_match(x: Result<u64, i64>) -> u16 {
// CHECK: %x = alloca [16 x i8]
// CHECK: store i64 %0, ptr %x
// CHECK: %[[DISCR:.+]] = load i64, ptr %x
// CHECK-SAME: !range ![[ZERO_ONE_64:[0-9]+]]
// CHECK: %[[COND:.+]] = trunc nuw i64 %[[DISCR]] to i1
// CHECK-NOT: %x = alloca
// CHECK: %[[OUT:.+]] = alloca [2 x i8]
// CHECK-NOT: %x = alloca
// CHECK: %[[COND:.+]] = trunc nuw i64 %x.0 to i1
// CHECK: br i1 %[[COND]], label %[[TRUE:[a-z0-9]+]], label %[[FALSE:[a-z0-9]+]]
// CHECK: [[TRUE]]:
// CHECK: store i16 13
// CHECK: store i16 13, ptr %[[OUT]]
// CHECK: [[FALSE]]:
// CHECK: store i16 42
// CHECK: store i16 42, ptr %[[OUT]]
// CHECK: %[[RET:.+]] = load i16, ptr %[[OUT]]
// CHECK: ret i16 %[[RET]]
match x {
Err(_) => 13,
Ok(_) => 42,
}
}
// CHECK: ![[ZERO_ONE_32]] = !{i32 0, i32 2}
// CHECK: ![[ZERO_ONE_64]] = !{i64 0, i64 2}

View file

@ -25,8 +25,8 @@ pub fn test(x: Option<bool>) {
path_b();
}
// CHECK-LABEL: @test(
// CHECK: %[[IS_NONE:.+]] = icmp eq i8 %0, 2
// CHECK-LABEL: void @test(i8{{.+}}%x)
// CHECK: %[[IS_NONE:.+]] = icmp eq i8 %x, 2
// CHECK: br i1 %[[IS_NONE]], label %bb2, label %bb1, !prof ![[NUM:[0-9]+]]
// CHECK: bb1:
// CHECK: path_a

View file

@ -1,5 +1,4 @@
//
//@ compile-flags: -Copt-level=3
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled
#![crate_type = "lib"]
pub enum Three {
@ -18,9 +17,9 @@ pub enum Four {
#[no_mangle]
pub fn three_valued(x: Three) -> Three {
// CHECK-LABEL: @three_valued
// CHECK-LABEL: i8 @three_valued(i8{{.+}}%x)
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i8 %0
// CHECK-NEXT: ret i8 %x
match x {
Three::A => Three::A,
Three::B => Three::B,
@ -30,9 +29,9 @@ pub fn three_valued(x: Three) -> Three {
#[no_mangle]
pub fn four_valued(x: Four) -> Four {
// CHECK-LABEL: @four_valued
// CHECK-LABEL: i16 @four_valued(i16{{.+}}%x)
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i16 %0
// CHECK-NEXT: ret i16 %x
match x {
Four::A => Four::A,
Four::B => Four::B,

View file

@ -17,10 +17,10 @@ use std::ptr::NonNull;
pub fn option_nop_match_32(x: Option<u32>) -> Option<u32> {
// CHECK: start:
// TWENTY-NEXT: %[[IS_SOME:.+]] = trunc nuw i32 %0 to i1
// TWENTY-NEXT: %.2 = select i1 %[[IS_SOME]], i32 %1, i32 undef
// TWENTY-NEXT: %[[PAYLOAD:.+]] = select i1 %[[IS_SOME]], i32 %1, i32 undef
// CHECK-NEXT: [[REG1:%.*]] = insertvalue { i32, i32 } poison, i32 %0, 0
// NINETEEN-NEXT: [[REG2:%.*]] = insertvalue { i32, i32 } [[REG1]], i32 %1, 1
// TWENTY-NEXT: [[REG2:%.*]] = insertvalue { i32, i32 } [[REG1]], i32 %.2, 1
// TWENTY-NEXT: [[REG2:%.*]] = insertvalue { i32, i32 } [[REG1]], i32 %[[PAYLOAD]], 1
// CHECK-NEXT: ret { i32, i32 } [[REG2]]
match x {
Some(x) => Some(x),
@ -33,7 +33,7 @@ pub fn option_nop_match_32(x: Option<u32>) -> Option<u32> {
pub fn option_nop_traits_32(x: Option<u32>) -> Option<u32> {
// CHECK: start:
// TWENTY-NEXT: %[[IS_SOME:.+]] = trunc nuw i32 %0 to i1
// TWENTY-NEXT: %.1 = select i1 %[[IS_SOME]], i32 %1, i32 undef
// TWENTY-NEXT: select i1 %[[IS_SOME]], i32 %1, i32 undef
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: ret { i32, i32 }