Add tag_for_variant query

This query allows for sharing code between `rustc_const_eval` and
`rustc_transmutability`.

Also moves `DummyMachine` to `rustc_const_eval`.
This commit is contained in:
Jack Wrenn 2024-03-20 17:45:14 +00:00
parent 9023f908cf
commit 2de9010f66
13 changed files with 347 additions and 292 deletions

View file

@ -0,0 +1,193 @@
use crate::interpret::{self, HasStaticRootDefId, ImmTy, Immediate, InterpCx, PointerArithmetic};
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult};
use rustc_middle::mir::*;
use rustc_middle::query::TyCtxtAt;
use rustc_middle::ty;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_span::def_id::DefId;
/// Macro for machine-specific `InterpError` without allocation.
/// (These will never be shown to the user, but they help diagnose ICEs.)
pub macro throw_machine_stop_str($($tt:tt)*) {{
// We make a new local type for it. The type itself does not carry any information,
// but its vtable (for the `MachineStopType` trait) does.
#[derive(Debug)]
struct Zst;
// Printing this type shows the desired string.
impl std::fmt::Display for Zst {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, $($tt)*)
}
}
impl rustc_middle::mir::interpret::MachineStopType for Zst {
fn diagnostic_message(&self) -> rustc_errors::DiagMessage {
self.to_string().into()
}
fn add_args(
self: Box<Self>,
_: &mut dyn FnMut(rustc_errors::DiagArgName, rustc_errors::DiagArgValue),
) {}
}
throw_machine_stop!(Zst)
}}
pub struct DummyMachine;
impl HasStaticRootDefId for DummyMachine {
fn static_def_id(&self) -> Option<rustc_hir::def_id::LocalDefId> {
None
}
}
impl<'mir, 'tcx: 'mir> interpret::Machine<'mir, 'tcx> for DummyMachine {
interpret::compile_time_machine!(<'mir, 'tcx>);
type MemoryKind = !;
const PANIC_ON_ALLOC_FAIL: bool = true;
#[inline(always)]
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
false // no reason to enforce alignment
}
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
false
}
fn before_access_global(
_tcx: TyCtxtAt<'tcx>,
_machine: &Self,
_alloc_id: AllocId,
alloc: ConstAllocation<'tcx>,
_static_def_id: Option<DefId>,
is_write: bool,
) -> InterpResult<'tcx> {
if is_write {
throw_machine_stop_str!("can't write to global");
}
// If the static allocation is mutable, then we can't const prop it as its content
// might be different at runtime.
if alloc.inner().mutability.is_mut() {
throw_machine_stop_str!("can't access mutable globals in ConstProp");
}
Ok(())
}
fn find_mir_or_eval_fn(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_instance: ty::Instance<'tcx>,
_abi: rustc_target::spec::abi::Abi,
_args: &[interpret::FnArg<'tcx, Self::Provenance>],
_destination: &interpret::MPlaceTy<'tcx, Self::Provenance>,
_target: Option<BasicBlock>,
_unwind: UnwindAction,
) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> {
unimplemented!()
}
fn panic_nounwind(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_msg: &str,
) -> interpret::InterpResult<'tcx> {
unimplemented!()
}
fn call_intrinsic(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_instance: ty::Instance<'tcx>,
_args: &[interpret::OpTy<'tcx, Self::Provenance>],
_destination: &interpret::MPlaceTy<'tcx, Self::Provenance>,
_target: Option<BasicBlock>,
_unwind: UnwindAction,
) -> interpret::InterpResult<'tcx> {
unimplemented!()
}
fn assert_panic(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_msg: &rustc_middle::mir::AssertMessage<'tcx>,
_unwind: UnwindAction,
) -> interpret::InterpResult<'tcx> {
unimplemented!()
}
fn binary_ptr_op(
ecx: &InterpCx<'mir, 'tcx, Self>,
bin_op: BinOp,
left: &interpret::ImmTy<'tcx, Self::Provenance>,
right: &interpret::ImmTy<'tcx, Self::Provenance>,
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
use rustc_middle::mir::BinOp::*;
Ok(match bin_op {
Eq | Ne | Lt | Le | Gt | Ge => {
// Types can differ, e.g. fn ptrs with different `for`.
assert_eq!(left.layout.abi, right.layout.abi);
let size = ecx.pointer_size();
// Just compare the bits. ScalarPairs are compared lexicographically.
// We thus always compare pairs and simply fill scalars up with 0.
// If the pointer has provenance, `to_bits` will return `Err` and we bail out.
let left = match **left {
Immediate::Scalar(l) => (l.to_bits(size)?, 0),
Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?),
Immediate::Uninit => panic!("we should never see uninit data here"),
};
let right = match **right {
Immediate::Scalar(r) => (r.to_bits(size)?, 0),
Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?),
Immediate::Uninit => panic!("we should never see uninit data here"),
};
let res = match bin_op {
Eq => left == right,
Ne => left != right,
Lt => left < right,
Le => left <= right,
Gt => left > right,
Ge => left >= right,
_ => bug!(),
};
(ImmTy::from_bool(res, *ecx.tcx), false)
}
// Some more operations are possible with atomics.
// The return value always has the provenance of the *left* operand.
Add | Sub | BitOr | BitAnd | BitXor => {
throw_machine_stop_str!("pointer arithmetic is not handled")
}
_ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
})
}
fn expose_ptr(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_ptr: interpret::Pointer<Self::Provenance>,
) -> interpret::InterpResult<'tcx> {
unimplemented!()
}
fn init_frame_extra(
_ecx: &mut InterpCx<'mir, 'tcx, Self>,
_frame: interpret::Frame<'mir, 'tcx, Self::Provenance>,
) -> interpret::InterpResult<
'tcx,
interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>,
> {
unimplemented!()
}
fn stack<'a>(
_ecx: &'a InterpCx<'mir, 'tcx, Self>,
) -> &'a [interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] {
// Return an empty stack instead of panicking, as `cur_span` uses it to evaluate constants.
&[]
}
fn stack_mut<'a>(
_ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
) -> &'a mut Vec<interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> {
unimplemented!()
}
}

View file

@ -3,7 +3,7 @@ use either::{Left, Right};
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{AllocId, ErrorHandled, InterpErrorInfo};
use rustc_middle::mir::{self, ConstAlloc, ConstValue};
use rustc_middle::query::TyCtxtAt;
use rustc_middle::query::{Key, TyCtxtAt};
use rustc_middle::traits::Reveal;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::print::with_no_trimmed_paths;
@ -243,6 +243,24 @@ pub(crate) fn turn_into_const_value<'tcx>(
op_to_const(&ecx, &mplace.into(), /* for diagnostics */ false)
}
/// Computes the tag (if any) for a given type and variant.
#[instrument(skip(tcx), level = "debug")]
pub fn tag_for_variant_provider<'tcx>(
tcx: TyCtxt<'tcx>,
(ty, variant_index): (Ty<'tcx>, abi::VariantIdx),
) -> Option<ty::ScalarInt> {
assert!(ty.is_enum());
let ecx = InterpCx::new(
tcx,
ty.default_span(tcx),
ty::ParamEnv::reveal_all(),
crate::const_eval::DummyMachine,
);
ecx.tag_for_variant(ty, variant_index).unwrap().map(|(tag, _tag_field)| tag)
}
#[instrument(skip(tcx), level = "debug")]
pub fn eval_to_const_value_raw_provider<'tcx>(
tcx: TyCtxt<'tcx>,

View file

@ -7,12 +7,14 @@ use rustc_middle::ty::{self, Ty};
use crate::interpret::format_interp_error;
mod dummy_machine;
mod error;
mod eval_queries;
mod fn_queries;
mod machine;
mod valtrees;
pub use dummy_machine::*;
pub use error::*;
pub use eval_queries::*;
pub use fn_queries::*;

View file

@ -2,7 +2,7 @@
use rustc_middle::mir;
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, ScalarInt, Ty};
use rustc_target::abi::{self, TagEncoding};
use rustc_target::abi::{VariantIdx, Variants};
@ -28,78 +28,27 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
throw_ub!(UninhabitedEnumVariantWritten(variant_index))
}
match dest.layout().variants {
abi::Variants::Single { index } => {
assert_eq!(index, variant_index);
}
abi::Variants::Multiple {
tag_encoding: TagEncoding::Direct,
tag: tag_layout,
tag_field,
..
} => {
match self.tag_for_variant(dest.layout().ty, variant_index)? {
Some((tag, tag_field)) => {
// No need to validate that the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
let discr_val = dest
.layout()
.ty
.discriminant_for_variant(*self.tcx, variant_index)
.unwrap()
.val;
// raw discriminants for enums are isize or bigger during
// their computation, but the in-memory tag is the smallest possible
// representation
let size = tag_layout.size(self);
let tag_val = size.truncate(discr_val);
// `TyAndLayout::for_variant()` call earlier already checks the
// variant is valid.
let tag_dest = self.project_field(dest, tag_field)?;
self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
self.write_scalar(tag, &tag_dest)
}
abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag: tag_layout,
tag_field,
..
} => {
// No need to validate that the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
if variant_index != untagged_variant {
let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
.checked_sub(variants_start)
.expect("overflow computing relative variant idx");
// We need to use machine arithmetic when taking into account `niche_start`:
// tag_val = variant_index_relative + niche_start_val
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
let variant_index_relative_val =
ImmTy::from_uint(variant_index_relative, tag_layout);
let tag_val = self.wrapping_binary_op(
mir::BinOp::Add,
&variant_index_relative_val,
&niche_start_val,
)?;
// Write result.
let niche_dest = self.project_field(dest, tag_field)?;
self.write_immediate(*tag_val, &niche_dest)?;
} else {
// The untagged variant is implicitly encoded simply by having a value that is
// outside the niche variants. But what if the data stored here does not
// actually encode this variant? That would be bad! So let's double-check...
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
if actual_variant != variant_index {
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
}
None => {
// No need to write the tag here, because an untagged variant is
// implicitly encoded. For `Niche`-optimized enums, it's by
// simply by having a value that is outside the niche variants.
// But what if the data stored here does not actually encode
// this variant? That would be bad! So let's double-check...
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
if actual_variant != variant_index {
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
}
Ok(())
}
}
Ok(())
}
/// Read discriminant, return the runtime value as well as the variant index.
@ -277,4 +226,77 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
};
Ok(ImmTy::from_scalar(discr_value, discr_layout))
}
/// Computes the tag value and its field number (if any) of a given variant
/// of type `ty`.
pub(crate) fn tag_for_variant(
&self,
ty: Ty<'tcx>,
variant_index: VariantIdx,
) -> InterpResult<'tcx, Option<(ScalarInt, usize)>> {
match self.layout_of(ty)?.variants {
abi::Variants::Single { index } => {
assert_eq!(index, variant_index);
Ok(None)
}
abi::Variants::Multiple {
tag_encoding: TagEncoding::Direct,
tag: tag_layout,
tag_field,
..
} => {
// raw discriminants for enums are isize or bigger during
// their computation, but the in-memory tag is the smallest possible
// representation
let discr = self.discriminant_for_variant(ty, variant_index)?;
let discr_size = discr.layout.size;
let discr_val = discr.to_scalar().to_bits(discr_size)?;
let tag_size = tag_layout.size(self);
let tag_val = tag_size.truncate(discr_val);
let tag = ScalarInt::try_from_uint(tag_val, tag_size).unwrap();
Ok(Some((tag, tag_field)))
}
abi::Variants::Multiple {
tag_encoding: TagEncoding::Niche { untagged_variant, .. },
..
} if untagged_variant == variant_index => {
// The untagged variant is implicitly encoded simply by having a
// value that is outside the niche variants.
Ok(None)
}
abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag: tag_layout,
tag_field,
..
} => {
assert!(variant_index != untagged_variant);
let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
.checked_sub(variants_start)
.expect("overflow computing relative variant idx");
// We need to use machine arithmetic when taking into account `niche_start`:
// tag_val = variant_index_relative + niche_start_val
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
let variant_index_relative_val =
ImmTy::from_uint(variant_index_relative, tag_layout);
let tag = self
.wrapping_binary_op(
mir::BinOp::Add,
&variant_index_relative_val,
&niche_start_val,
)?
.to_scalar()
.try_to_int()
.unwrap();
Ok(Some((tag, tag_field)))
}
}
}
}

View file

@ -40,6 +40,7 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
pub fn provide(providers: &mut Providers) {
const_eval::provide(providers);
providers.tag_for_variant = const_eval::tag_for_variant_provider;
providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
providers.eval_static_initializer = const_eval::eval_static_initializer_provider;