1
Fork 0

Handle custom discriminant values and detect invalid discriminants.

This commit is contained in:
Scott Olson 2016-03-28 21:08:08 -06:00
parent 71b94c9a5d
commit 63fdd46f9a
3 changed files with 48 additions and 10 deletions

View file

@ -5,6 +5,7 @@ use std::fmt;
pub enum EvalError { pub enum EvalError {
DanglingPointerDeref, DanglingPointerDeref,
InvalidBool, InvalidBool,
InvalidDiscriminant,
PointerOutOfBounds, PointerOutOfBounds,
ReadPointerAsBytes, ReadPointerAsBytes,
ReadBytesAsPointer, ReadBytesAsPointer,
@ -21,6 +22,8 @@ impl Error for EvalError {
"dangling pointer was dereferenced", "dangling pointer was dereferenced",
EvalError::InvalidBool => EvalError::InvalidBool =>
"invalid boolean value read", "invalid boolean value read",
EvalError::InvalidDiscriminant =>
"invalid enum discriminant value read",
EvalError::PointerOutOfBounds => EvalError::PointerOutOfBounds =>
"pointer offset outside bounds of allocation", "pointer offset outside bounds of allocation",
EvalError::ReadPointerAsBytes => EvalError::ReadPointerAsBytes =>

View file

@ -228,7 +228,7 @@ impl<'a, 'tcx: 'a, 'arena> Interpreter<'a, 'tcx, 'arena> {
TerminatorTarget::Block(target_block) TerminatorTarget::Block(target_block)
} }
Switch { ref discr, ref targets, .. } => { Switch { ref discr, ref targets, adt_def } => {
let adt_ptr = try!(self.eval_lvalue(discr)).to_ptr(); let adt_ptr = try!(self.eval_lvalue(discr)).to_ptr();
let adt_repr = self.lvalue_repr(discr); let adt_repr = self.lvalue_repr(discr);
let discr_size = match *adt_repr { let discr_size = match *adt_repr {
@ -236,7 +236,14 @@ impl<'a, 'tcx: 'a, 'arena> Interpreter<'a, 'tcx, 'arena> {
_ => panic!("attmpted to switch on non-aggregate type"), _ => panic!("attmpted to switch on non-aggregate type"),
}; };
let discr_val = try!(self.memory.read_uint(adt_ptr, discr_size)); let discr_val = try!(self.memory.read_uint(adt_ptr, discr_size));
TerminatorTarget::Block(targets[discr_val as usize])
let matching = adt_def.variants.iter()
.position(|v| discr_val == v.disr_val.to_u64_unchecked());
match matching {
Some(i) => TerminatorTarget::Block(targets[i]),
None => return Err(EvalError::InvalidDiscriminant),
}
} }
Call { ref func, ref args, ref destination, .. } => { Call { ref func, ref args, ref destination, .. } => {
@ -481,13 +488,18 @@ impl<'a, 'tcx: 'a, 'arena> Interpreter<'a, 'tcx, 'arena> {
Ok(TerminatorTarget::Call) Ok(TerminatorTarget::Call)
} }
fn assign_to_aggregate(&mut self, dest: Pointer, dest_repr: &Repr, variant: usize, fn assign_to_aggregate(
operands: &[mir::Operand<'tcx>]) -> EvalResult<()> { &mut self,
dest: Pointer,
dest_repr: &Repr,
variant: usize,
discr: Option<u64>,
operands: &[mir::Operand<'tcx>],
) -> EvalResult<()> {
match *dest_repr { match *dest_repr {
Repr::Aggregate { discr_size, ref variants, .. } => { Repr::Aggregate { discr_size, ref variants, .. } => {
if discr_size > 0 { if discr_size > 0 {
let discr = variant as u64; try!(self.memory.write_uint(dest, discr.unwrap(), discr_size));
try!(self.memory.write_uint(dest, discr, discr_size));
} }
let after_discr = dest.offset(discr_size as isize); let after_discr = dest.offset(discr_size as isize);
for (field, operand) in variants[variant].iter().zip(operands) { for (field, operand) in variants[variant].iter().zip(operands) {
@ -538,10 +550,12 @@ impl<'a, 'tcx: 'a, 'arena> Interpreter<'a, 'tcx, 'arena> {
use rustc::mir::repr::AggregateKind::*; use rustc::mir::repr::AggregateKind::*;
match *kind { match *kind {
Tuple | Closure(..) => Tuple | Closure(..) =>
try!(self.assign_to_aggregate(dest, &dest_repr, 0, operands)), try!(self.assign_to_aggregate(dest, &dest_repr, 0, None, operands)),
Adt(_, variant_idx, _) => Adt(adt_def, variant, _) => {
try!(self.assign_to_aggregate(dest, &dest_repr, variant_idx, operands)), let discr = Some(adt_def.variants[variant].disr_val.to_u64_unchecked());
try!(self.assign_to_aggregate(dest, &dest_repr, variant, discr, operands));
}
Vec => if let Repr::Array { elem_size, length } = *dest_repr { Vec => if let Repr::Array { elem_size, length } = *dest_repr {
assert_eq!(length, operands.len()); assert_eq!(length, operands.len());
@ -668,7 +682,7 @@ impl<'a, 'tcx: 'a, 'arena> Interpreter<'a, 'tcx, 'arena> {
use rustc::mir::tcx::LvalueTy; use rustc::mir::tcx::LvalueTy;
match self.mir().lvalue_ty(self.tcx, lvalue) { match self.mir().lvalue_ty(self.tcx, lvalue) {
LvalueTy::Ty { ty } => self.ty_to_repr(ty), LvalueTy::Ty { ty } => self.ty_to_repr(ty),
LvalueTy::Downcast { ref adt_def, substs, variant_index } => { LvalueTy::Downcast { adt_def, substs, variant_index } => {
let field_tys = adt_def.variants[variant_index].fields.iter() let field_tys = adt_def.variants[variant_index].fields.iter()
.map(|f| f.ty(self.tcx, substs)); .map(|f| f.ty(self.tcx, substs));
self.repr_arena.alloc(self.make_aggregate_repr(iter::once(field_tys))) self.repr_arena.alloc(self.make_aggregate_repr(iter::once(field_tys)))

21
test/c_enums.rs Executable file
View file

@ -0,0 +1,21 @@
#![feature(custom_attribute)]
#![allow(dead_code, unused_attributes)]
enum Foo {
Bar = 42,
Baz,
Quux = 100,
}
#[miri_run]
fn foo() -> [u8; 3] {
[Foo::Bar as u8, Foo::Baz as u8, Foo::Quux as u8]
}
#[miri_run]
fn unsafe_match() -> bool {
match unsafe { std::mem::transmute::<u8, Foo>(43) } {
Foo::Baz => true,
_ => false,
}
}