interpret: refactor dyn trait handling

We can check that the vtable is for the right trait very early, and then just pass the type around.
This commit is contained in:
Ralf Jung 2024-06-10 17:24:36 +02:00
parent 0de24a5177
commit af4d6c74ef
7 changed files with 90 additions and 112 deletions

View file

@ -383,7 +383,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
match (&src_pointee_ty.kind(), &dest_pointee_ty.kind()) { match (&src_pointee_ty.kind(), &dest_pointee_ty.kind()) {
(&ty::Array(_, length), &ty::Slice(_)) => { (&ty::Array(_, length), &ty::Slice(_)) => {
let ptr = self.read_pointer(src)?; let ptr = self.read_pointer(src)?;
// u64 cast is from usize to u64, which is always good
let val = Immediate::new_slice( let val = Immediate::new_slice(
ptr, ptr,
length.eval_target_usize(*self.tcx, self.param_env), length.eval_target_usize(*self.tcx, self.param_env),
@ -401,13 +400,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let (old_data, old_vptr) = val.to_scalar_pair(); let (old_data, old_vptr) = val.to_scalar_pair();
let old_data = old_data.to_pointer(self)?; let old_data = old_data.to_pointer(self)?;
let old_vptr = old_vptr.to_pointer(self)?; let old_vptr = old_vptr.to_pointer(self)?;
let (ty, old_trait) = self.get_ptr_vtable(old_vptr)?; let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
if old_trait != data_a.principal() {
throw_ub!(InvalidVTableTrait {
expected_trait: data_a,
vtable_trait: old_trait,
});
}
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?; let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest) self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
} }

View file

@ -867,19 +867,28 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
.ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into()) .ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into())
} }
pub fn get_ptr_vtable( /// Get the dynamic type of the given vtable pointer.
/// If `expected_trait` is `Some`, it must be a vtable for the given trait.
pub fn get_ptr_vtable_ty(
&self, &self,
ptr: Pointer<Option<M::Provenance>>, ptr: Pointer<Option<M::Provenance>>,
) -> InterpResult<'tcx, (Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>)> { expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
) -> InterpResult<'tcx, Ty<'tcx>> {
trace!("get_ptr_vtable({:?})", ptr); trace!("get_ptr_vtable({:?})", ptr);
let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?; let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?;
if offset.bytes() != 0 { if offset.bytes() != 0 {
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))) throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
} }
match self.tcx.try_get_global_alloc(alloc_id) { let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
Some(GlobalAlloc::VTable(ty, trait_ref)) => Ok((ty, trait_ref)), else {
_ => throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))), throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
};
if let Some(expected_trait) = expected_trait {
if vtable_trait != expected_trait.principal() {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
} }
Ok(ty)
} }
pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> { pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> {

View file

@ -9,7 +9,6 @@ use tracing::{instrument, trace};
use rustc_ast::Mutability; use rustc_ast::Mutability;
use rustc_middle::mir; use rustc_middle::mir;
use rustc_middle::ty;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::Ty; use rustc_middle::ty::Ty;
use rustc_middle::{bug, span_bug}; use rustc_middle::{bug, span_bug};
@ -1017,54 +1016,6 @@ where
let layout = self.layout_of(raw.ty)?; let layout = self.layout_of(raw.ty)?;
Ok(self.ptr_to_mplace(ptr.into(), layout)) Ok(self.ptr_to_mplace(ptr.into(), layout))
} }
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
/// Aso returns the vtable.
pub(super) fn unpack_dyn_trait(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, Pointer<Option<M::Provenance>>)> {
assert!(
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
"`unpack_dyn_trait` only makes sense on `dyn*` types"
);
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
if expected_trait.principal() != vtable_trait {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
// This is a kind of transmute, from a place with unsized type and metadata to
// a place with sized type and no metadata.
let layout = self.layout_of(ty)?;
let mplace =
MPlaceTy { mplace: MemPlace { meta: MemPlaceMeta::None, ..mplace.mplace }, layout };
Ok((mplace, vtable))
}
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
/// Also returns the vtable.
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
&self,
val: &P,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, (P, Pointer<Option<M::Provenance>>)> {
assert!(
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
"`unpack_dyn_star` only makes sense on `dyn*` types"
);
let data = self.project_field(val, 0)?;
let vtable = self.project_field(val, 1)?;
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
if expected_trait.principal() != vtable_trait {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
// `data` is already the right thing but has the wrong type. So we transmute it.
let layout = self.layout_of(ty)?;
let data = data.transmute(layout, self)?;
Ok((data, vtable))
}
} }
// Some nodes are used a lot. Make sure they don't unintentionally get bigger. // Some nodes are used a lot. Make sure they don't unintentionally get bigger.

View file

@ -1,6 +1,7 @@
use std::borrow::Cow; use std::borrow::Cow;
use either::Either; use either::Either;
use rustc_middle::ty::TyCtxt;
use tracing::trace; use tracing::trace;
use rustc_middle::span_bug; use rustc_middle::span_bug;
@ -827,20 +828,19 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}; };
// Obtain the underlying trait we are working on, and the adjusted receiver argument. // Obtain the underlying trait we are working on, and the adjusted receiver argument.
let (vptr, dyn_ty, adjusted_receiver) = if let ty::Dynamic(data, _, ty::DynStar) = let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
receiver_place.layout.ty.kind() receiver_place.layout.ty.kind()
{ {
let (recv, vptr) = self.unpack_dyn_star(&receiver_place, data)?; let recv = self.unpack_dyn_star(&receiver_place, data)?;
let (dyn_ty, _dyn_trait) = self.get_ptr_vtable(vptr)?;
(vptr, dyn_ty, recv.ptr()) (data.principal(), recv.layout.ty, recv.ptr())
} else { } else {
// Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`. // Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`.
// (For that reason we also cannot use `unpack_dyn_trait`.) // (For that reason we also cannot use `unpack_dyn_trait`.)
let receiver_tail = self let receiver_tail = self
.tcx .tcx
.struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env); .struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env);
let ty::Dynamic(data, _, ty::Dyn) = receiver_tail.kind() else { let ty::Dynamic(receiver_trait, _, ty::Dyn) = receiver_tail.kind() else {
span_bug!( span_bug!(
self.cur_span(), self.cur_span(),
"dynamic call on non-`dyn` type {}", "dynamic call on non-`dyn` type {}",
@ -851,25 +851,24 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
// Get the required information from the vtable. // Get the required information from the vtable.
let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?; let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?;
let (dyn_ty, dyn_trait) = self.get_ptr_vtable(vptr)?; let dyn_ty = self.get_ptr_vtable_ty(vptr, Some(receiver_trait))?;
if dyn_trait != data.principal() {
throw_ub!(InvalidVTableTrait {
expected_trait: data,
vtable_trait: dyn_trait,
});
}
// It might be surprising that we use a pointer as the receiver even if this // It might be surprising that we use a pointer as the receiver even if this
// is a by-val case; this works because by-val passing of an unsized `dyn // is a by-val case; this works because by-val passing of an unsized `dyn
// Trait` to a function is actually desugared to a pointer. // Trait` to a function is actually desugared to a pointer.
(vptr, dyn_ty, receiver_place.ptr()) (receiver_trait.principal(), dyn_ty, receiver_place.ptr())
}; };
// Now determine the actual method to call. We can do that in two different ways and // Now determine the actual method to call. We can do that in two different ways and
// compare them to ensure everything fits. // compare them to ensure everything fits.
let Some(ty::VtblEntry::Method(fn_inst)) = let vtable_entries = if let Some(dyn_trait) = dyn_trait {
self.get_vtable_entries(vptr)?.get(idx).copied() let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
else { let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
};
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this // FIXME(fee1-dead) these could be variants of the UB info enum instead of this
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method); throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
}; };
@ -898,7 +897,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty); let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty);
args[0] = FnArg::Copy( args[0] = FnArg::Copy(
ImmTy::from_immediate( ImmTy::from_immediate(
Scalar::from_maybe_pointer(adjusted_receiver, self).into(), Scalar::from_maybe_pointer(adjusted_recv, self).into(),
self.layout_of(receiver_ty)?, self.layout_of(receiver_ty)?,
) )
.into(), .into(),
@ -974,11 +973,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let place = match place.layout.ty.kind() { let place = match place.layout.ty.kind() {
ty::Dynamic(data, _, ty::Dyn) => { ty::Dynamic(data, _, ty::Dyn) => {
// Dropping a trait object. Need to find actual drop fn. // Dropping a trait object. Need to find actual drop fn.
self.unpack_dyn_trait(&place, data)?.0 self.unpack_dyn_trait(&place, data)?
} }
ty::Dynamic(data, _, ty::DynStar) => { ty::Dynamic(data, _, ty::DynStar) => {
// Dropping a `dyn*`. Need to find actual drop fn. // Dropping a `dyn*`. Need to find actual drop fn.
self.unpack_dyn_star(&place, data)?.0 self.unpack_dyn_star(&place, data)?
} }
_ => { _ => {
debug_assert_eq!( debug_assert_eq!(

View file

@ -1,11 +1,11 @@
use rustc_middle::mir::interpret::{InterpResult, Pointer}; use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_middle::ty::{self, Ty};
use rustc_target::abi::{Align, Size}; use rustc_target::abi::{Align, Size};
use tracing::trace; use tracing::trace;
use super::util::ensure_monomorphic_enough; use super::util::ensure_monomorphic_enough;
use super::{InterpCx, Machine}; use super::{InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};
impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
/// Creates a dynamic vtable for the given type and vtable origin. This is used only for /// Creates a dynamic vtable for the given type and vtable origin. This is used only for
@ -33,28 +33,58 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok(vtable_ptr.into()) Ok(vtable_ptr.into())
} }
/// Returns a high-level representation of the entries of the given vtable.
pub fn get_vtable_entries(
&self,
vtable: Pointer<Option<M::Provenance>>,
) -> InterpResult<'tcx, &'tcx [ty::VtblEntry<'tcx>]> {
let (ty, poly_trait_ref) = self.get_ptr_vtable(vtable)?;
Ok(if let Some(poly_trait_ref) = poly_trait_ref {
let trait_ref = poly_trait_ref.with_self_ty(*self.tcx, ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
})
}
pub fn get_vtable_size_and_align( pub fn get_vtable_size_and_align(
&self, &self,
vtable: Pointer<Option<M::Provenance>>, vtable: Pointer<Option<M::Provenance>>,
) -> InterpResult<'tcx, (Size, Align)> { ) -> InterpResult<'tcx, (Size, Align)> {
let (ty, _trait_ref) = self.get_ptr_vtable(vtable)?; let ty = self.get_ptr_vtable_ty(vtable, None)?;
let layout = self.layout_of(ty)?; let layout = self.layout_of(ty)?;
assert!(layout.is_sized(), "there are no vtables for unsized types"); assert!(layout.is_sized(), "there are no vtables for unsized types");
Ok((layout.size, layout.align.abi)) Ok((layout.size, layout.align.abi))
} }
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
pub(super) fn unpack_dyn_trait(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, MPlaceTy<'tcx, M::Provenance>> {
assert!(
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
"`unpack_dyn_trait` only makes sense on `dyn*` types"
);
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
// This is a kind of transmute, from a place with unsized type and metadata to
// a place with sized type and no metadata.
let layout = self.layout_of(ty)?;
let mplace = mplace.offset_with_meta(
Size::ZERO,
OffsetMode::Wrapping,
MemPlaceMeta::None,
layout,
self,
)?;
Ok(mplace)
}
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
&self,
val: &P,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, P> {
assert!(
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
"`unpack_dyn_star` only makes sense on `dyn*` types"
);
let data = self.project_field(val, 0)?;
let vtable = self.project_field(val, 1)?;
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
// `data` is already the right thing but has the wrong type. So we transmute it.
let layout = self.layout_of(ty)?;
let data = data.transmute(layout, self)?;
Ok(data)
}
} }

View file

@ -343,20 +343,16 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
match tail.kind() { match tail.kind() {
ty::Dynamic(data, _, ty::Dyn) => { ty::Dynamic(data, _, ty::Dyn) => {
let vtable = meta.unwrap_meta().to_pointer(self.ecx)?; let vtable = meta.unwrap_meta().to_pointer(self.ecx)?;
// Make sure it is a genuine vtable pointer. // Make sure it is a genuine vtable pointer for the right trait.
let (_dyn_ty, dyn_trait) = try_validation!( try_validation!(
self.ecx.get_ptr_vtable(vtable), self.ecx.get_ptr_vtable_ty(vtable, Some(data)),
self.path, self.path,
Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) => Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) =>
InvalidVTablePtr { value: format!("{vtable}") } InvalidVTablePtr { value: format!("{vtable}") },
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
},
); );
// Make sure it is for the right trait.
if dyn_trait != data.principal() {
throw_validation_failure!(
self.path,
InvalidMetaWrongTrait { expected_trait: data, vtable_trait: dyn_trait }
);
}
} }
ty::Slice(..) | ty::Str => { ty::Slice(..) | ty::Str => {
let _len = meta.unwrap_meta().to_target_usize(self.ecx)?; let _len = meta.unwrap_meta().to_target_usize(self.ecx)?;

View file

@ -95,7 +95,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized {
// unsized values are never immediate, so we can assert_mem_place // unsized values are never immediate, so we can assert_mem_place
let op = v.to_op(self.ecx())?; let op = v.to_op(self.ecx())?;
let dest = op.assert_mem_place(); let dest = op.assert_mem_place();
let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?.0; let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?;
trace!("walk_value: dyn object layout: {:#?}", inner_mplace.layout); trace!("walk_value: dyn object layout: {:#?}", inner_mplace.layout);
// recurse with the inner type // recurse with the inner type
return self.visit_field(v, 0, &inner_mplace.into()); return self.visit_field(v, 0, &inner_mplace.into());
@ -104,7 +104,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized {
// DynStar types. Very different from a dyn type (but strangely part of the // DynStar types. Very different from a dyn type (but strangely part of the
// same variant in `TyKind`): These are pairs where the 2nd component is the // same variant in `TyKind`): These are pairs where the 2nd component is the
// vtable, and the first component is the data (which must be ptr-sized). // vtable, and the first component is the data (which must be ptr-sized).
let data = self.ecx().unpack_dyn_star(v, data)?.0; let data = self.ecx().unpack_dyn_star(v, data)?;
return self.visit_field(v, 0, &data); return self.visit_field(v, 0, &data);
} }
// Slices do not need special handling here: they have `Array` field // Slices do not need special handling here: they have `Array` field