Check alignment of pointers only when read/written through

This commit is contained in:
Ben Kimock 2023-07-04 14:23:16 -04:00
parent 7314873326
commit f9bd7dabcf
10 changed files with 143 additions and 59 deletions

View file

@ -1,13 +1,12 @@
use crate::MirPass;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
interpret::Scalar,
visit::{PlaceContext, Visitor},
visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
};
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;
pub struct CheckAlignment;
@ -30,7 +29,12 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
// This pass inserts new blocks. Each insertion changes the Location for all
// statements/blocks after. Iterating or visiting the MIR in order would require updating
// our current location after every insertion. By iterating backwards, we dodge this issue:
// The only Locations that an insertion changes have already been handled.
for block in (0..basic_blocks.len()).rev() {
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
@ -38,22 +42,19 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;
let mut finder = PointerFinder {
local_decls,
tcx,
pointers: Vec::new(),
def_id: body.source.def_id(),
};
for (pointer, pointee_ty) in finder.find_pointers(statement) {
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
let mut finder =
PointerFinder { tcx, local_decls, param_env, pointers: Vec::new() };
finder.visit_statement(statement, location);
for (local, ty) in finder.pointers {
debug!("Inserting alignment check for {:?}", ty);
let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
pointer,
pointee_ty,
local,
ty,
source_info,
new_block,
);
@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
}
}
impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
self.pointers.clear();
self.visit_statement(statement, Location::START);
core::mem::take(&mut self.pointers)
}
}
struct PointerFinder<'tcx, 'a> {
local_decls: &'a mut LocalDecls<'tcx>,
tcx: TyCtxt<'tcx>,
def_id: DefId,
local_decls: &'a mut LocalDecls<'tcx>,
param_env: ParamEnv<'tcx>,
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}
impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
if let Rvalue::AddressOf(..) = rvalue {
// Ignore dereferences inside of an AddressOf
return;
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
// We want to only check reads and writes to Places, so we specifically exclude
// Borrows and AddressOf.
match context {
PlaceContext::MutatingUse(
MutatingUseContext::Store
| MutatingUseContext::AsmOutput
| MutatingUseContext::Call
| MutatingUseContext::Yield
| MutatingUseContext::Drop,
) => {}
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}
self.super_rvalue(rvalue, location);
}
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if let PlaceContext::NonUse(_) = context {
return;
}
if !place.is_indirect() {
return;
}
// Since Deref projections must come first and only once, the pointer for an indirect place
// is the Local that the Place is based on.
let pointer = Place::from(place.local);
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
let pointer_ty = self.local_decls[place.local].ty;
// We only want to check unsafe pointers
// We only want to check places based on unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
return;
}
let Some(pointee) = pointer_ty.builtin_deref(true) else {
debug!("Indirect but no builtin deref: {:?}", pointer_ty);
let pointee_ty =
pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
// Ideally we'd support this in the future, but for now we are limited to sized types.
if !pointee_ty.is_sized(self.tcx, self.param_env) {
debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
return;
}
// Try to detect types we are sure have an alignment of 1 and skip the check
// We don't need to look for str and slices, we already rejected unsized types above
let element_ty = match pointee_ty.kind() {
ty::Array(ty, _) => *ty,
_ => pointee_ty,
};
let mut pointee_ty = pointee.ty;
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
pointee_ty = pointee_ty.sequence_element_type(self.tcx);
}
if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
debug!("Trivially aligned place type: {:?}", pointee_ty);
return;
}
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
.contains(&pointee_ty)
{
debug!("Trivially aligned pointee type: {:?}", pointer_ty);
return;
}
// Ensure that this place is based on an aligned pointer.
self.pointers.push((pointer, pointee_ty));
self.pointers.push((pointer, pointee_ty))
self.super_place(place, context, location);
}
}