Enable GVN for AggregateKind::RawPtr & UnOp::PtrMetadata

This commit is contained in:
Scott McMurray 2024-05-12 00:22:42 -07:00
parent 003a902792
commit 021ccf6c4e
14 changed files with 519 additions and 26 deletions

View file

@ -83,8 +83,8 @@
//! that contain `AllocId`s.
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemPlaceMeta, MemoryKind};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable, Scalar};
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_hir::def::DefKind;
@ -99,7 +99,7 @@ use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
use rustc_target::abi::{self, Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
use smallvec::SmallVec;
use std::borrow::Cow;
@ -177,6 +177,12 @@ enum AggregateTy<'tcx> {
Array,
Tuple,
Def(DefId, ty::GenericArgsRef<'tcx>),
RawPtr {
/// Needed for cast propagation.
data_pointer_ty: Ty<'tcx>,
/// The data pointer can be anything thin, so doesn't determine the output.
output_pointer_ty: Ty<'tcx>,
},
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
@ -385,11 +391,22 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
AggregateTy::Def(def_id, args) => {
self.tcx.type_of(def_id).instantiate(self.tcx, args)
}
AggregateTy::RawPtr { output_pointer_ty, .. } => output_pointer_ty,
};
let variant = if ty.is_enum() { Some(variant) } else { None };
let ty = self.ecx.layout_of(ty).ok()?;
if ty.is_zst() {
ImmTy::uninit(ty).into()
} else if matches!(kind, AggregateTy::RawPtr { .. }) {
// Pointers don't have fields, so don't `project_field` them.
let data = self.ecx.read_pointer(fields[0]).ok()?;
let meta = if fields[1].layout.is_zst() {
MemPlaceMeta::None
} else {
MemPlaceMeta::Meta(self.ecx.read_scalar(fields[1]).ok()?)
};
let ptr_imm = Immediate::new_pointer_with_meta(data, meta, &self.ecx);
ImmTy::from_immediate(ptr_imm, ty).into()
} else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?;
let variant_dest = if let Some(variant) = variant {
@ -862,10 +879,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
rvalue: &mut Rvalue<'tcx>,
location: Location,
) -> Option<VnIndex> {
let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() };
let Rvalue::Aggregate(box ref kind, ref mut field_ops) = *rvalue else { bug!() };
let tcx = self.tcx;
if fields.is_empty() {
if field_ops.is_empty() {
let is_zst = match *kind {
AggregateKind::Array(..)
| AggregateKind::Tuple
@ -884,13 +901,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}
let (ty, variant_index) = match *kind {
let (mut ty, variant_index) = match *kind {
AggregateKind::Array(..) => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Array, FIRST_VARIANT)
}
AggregateKind::Tuple => {
assert!(!fields.is_empty());
assert!(!field_ops.is_empty());
(AggregateTy::Tuple, FIRST_VARIANT)
}
AggregateKind::Closure(did, args)
@ -901,15 +918,49 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
// Do not track unions.
AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
// FIXME: Do the extra work to GVN `from_raw_parts`
AggregateKind::RawPtr(..) => return None,
AggregateKind::RawPtr(pointee_ty, mtbl) => {
assert_eq!(field_ops.len(), 2);
let data_pointer_ty = field_ops[FieldIdx::ZERO].ty(self.local_decls, self.tcx);
let output_pointer_ty = Ty::new_ptr(self.tcx, pointee_ty, mtbl);
(AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty }, FIRST_VARIANT)
}
};
let fields: Option<Vec<_>> = fields
let fields: Option<Vec<_>> = field_ops
.iter_mut()
.map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
.collect();
let fields = fields?;
let mut fields = fields?;
if let AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty } = &mut ty {
let mut was_updated = false;
// Any thin pointer of matching mutability is fine as the data pointer.
while let Value::Cast {
kind: CastKind::PtrToPtr,
value: cast_value,
from: cast_from,
to: _,
} = self.get(fields[0])
&& let ty::RawPtr(from_pointee_ty, from_mtbl) = cast_from.kind()
&& let ty::RawPtr(_, output_mtbl) = output_pointer_ty.kind()
&& from_mtbl == output_mtbl
&& from_pointee_ty.is_sized(self.tcx, self.param_env)
{
fields[0] = *cast_value;
*data_pointer_ty = *cast_from;
was_updated = true;
}
if was_updated {
if let Some(const_) = self.try_as_constant(fields[0]) {
field_ops[FieldIdx::ZERO] = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(fields[0], location) {
field_ops[FieldIdx::ZERO] = Operand::Copy(Place::from(local));
self.reused_locals.insert(local);
}
}
}
if let AggregateTy::Array = ty
&& fields.len() > 4
@ -941,6 +992,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
(UnOp::Not, Value::BinaryOp(BinOp::Ne, lhs, rhs)) => {
Value::BinaryOp(BinOp::Eq, *lhs, *rhs)
}
(UnOp::PtrMetadata, Value::Aggregate(AggregateTy::RawPtr { .. }, _, fields)) => {
return Some(fields[1]);
}
_ => return None,
};
@ -1092,6 +1146,23 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
return self.new_opaque();
}
let mut was_updated = false;
// If that cast just casts away the metadata again,
if let PtrToPtr = kind
&& let Value::Aggregate(AggregateTy::RawPtr { data_pointer_ty, .. }, _, fields) =
self.get(value)
&& let ty::RawPtr(to_pointee, _) = to.kind()
&& to_pointee.is_sized(self.tcx, self.param_env)
{
from = *data_pointer_ty;
value = fields[0];
was_updated = true;
if *data_pointer_ty == to {
return Some(fields[0]);
}
}
if let PtrToPtr | PointerCoercion(MutToConstPointer) = kind
&& let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } =
*self.get(value)
@ -1100,9 +1171,13 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
from = inner_from;
value = inner_value;
*kind = PtrToPtr;
was_updated = true;
if inner_from == to {
return Some(inner_value);
}
}
if was_updated {
if let Some(const_) = self.try_as_constant(value) {
*operand = Operand::Constant(Box::new(const_));
} else if let Some(local) = self.try_as_local(value, location) {