
Improve mir interpreter performance by caching

hkalbasi 2023-08-04 16:04:40 +03:30
parent e37ec7262c
commit 3115d6988f
2 changed files with 233 additions and 139 deletions
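
The change is essentially a set of per-`Evaluator` memoization caches: `projected_ty_cache`, `mir_or_dyn_index_cache`, and `not_special_fn_cache` remember expensive query results, `unused_locals_store` recycles `Locals` allocations between stack frames, and fields like `cached_ptr_size` and the cached `Fn`/`FnMut`/`FnOnce` trait methods precompute values that were previously looked up on every call. As a rough, self-contained sketch of the `RefCell` + hash-map pattern the new fields follow (std `HashMap` standing in for `FxHashMap`, and a made-up `expensive_square` in place of the real type/MIR queries):

```rust
use std::cell::RefCell;
use std::collections::HashMap;

// Hypothetical stand-in for the evaluator; the real struct keeps several such
// caches (`projected_ty_cache`, `mir_or_dyn_index_cache`, ...) as fields.
struct Evaluator {
    // `RefCell` gives interior mutability, so `&self` helpers can still memoize.
    square_cache: RefCell<HashMap<u64, u64>>,
}

impl Evaluator {
    fn expensive_square(&self, n: u64) -> u64 {
        // Fast path: answer from the cache under a shared borrow.
        if let Some(&r) = self.square_cache.borrow().get(&n) {
            return r;
        }
        // Slow path: compute once, then store the result for later calls.
        let r = n * n; // stands in for an expensive query (layout, projected type, MIR body, ...)
        self.square_cache.borrow_mut().insert(n, r);
        r
    }
}

fn main() {
    let ev = Evaluator { square_cache: RefCell::new(HashMap::new()) };
    assert_eq!(ev.expensive_square(12), 144);
    assert_eq!(ev.expensive_square(12), 144); // second call is a plain hash-map hit
}
```

The `RefCell` matters because the evaluator's helpers take `&self`; interior mutability lets them populate the caches without threading `&mut self` through every call site.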

View file

@@ -1,6 +1,13 @@
 //! This module provides a MIR interpreter, which is used in const eval.
 
-use std::{borrow::Cow, cell::RefCell, collections::HashMap, fmt::Write, iter, mem, ops::Range};
+use std::{
+    borrow::Cow,
+    cell::RefCell,
+    collections::{HashMap, HashSet},
+    fmt::Write,
+    iter, mem,
+    ops::Range,
+};
 
 use base_db::{CrateId, FileId};
 use chalk_ir::Mutability;
@@ -39,7 +46,8 @@ use crate::{
 use super::{
     return_slot, AggregateKind, BasicBlockId, BinOp, CastKind, LocalId, MirBody, MirLowerError,
-    MirSpan, Operand, Place, ProjectionElem, Rvalue, StatementKind, TerminatorKind, UnOp,
+    MirSpan, Operand, Place, PlaceElem, ProjectionElem, Rvalue, StatementKind, TerminatorKind,
+    UnOp,
 };
 
 mod shim;
@@ -120,13 +128,18 @@ impl TlsData {
 }
 
 struct StackFrame {
-    body: Arc<MirBody>,
     locals: Locals,
     destination: Option<BasicBlockId>,
     prev_stack_ptr: usize,
     span: (MirSpan, DefWithBodyId),
 }
 
+#[derive(Clone)]
+enum MirOrDynIndex {
+    Mir(Arc<MirBody>),
+    Dyn(usize),
+}
+
 pub struct Evaluator<'a> {
     db: &'a dyn HirDatabase,
     trait_env: Arc<TraitEnvironment>,
@@ -145,6 +158,17 @@ pub struct Evaluator<'a> {
     stdout: Vec<u8>,
     stderr: Vec<u8>,
     layout_cache: RefCell<FxHashMap<Ty, Arc<Layout>>>,
+    projected_ty_cache: RefCell<FxHashMap<(Ty, PlaceElem), Ty>>,
+    not_special_fn_cache: RefCell<FxHashSet<FunctionId>>,
+    mir_or_dyn_index_cache: RefCell<FxHashMap<(FunctionId, Substitution), MirOrDynIndex>>,
+    /// Constantly dropping and creating `Locals` is very costly. We store
+    /// old locals that we normally want to drop here, to reuse their allocations
+    /// later.
+    unused_locals_store: RefCell<FxHashMap<DefWithBodyId, Vec<Locals>>>,
+    cached_ptr_size: usize,
+    cached_fn_trait_func: Option<FunctionId>,
+    cached_fn_mut_trait_func: Option<FunctionId>,
+    cached_fn_once_trait_func: Option<FunctionId>,
     crate_id: CrateId,
     // FIXME: This is a workaround, see the comment on `interpret_mir`
     assert_placeholder_ty_is_unused: bool,
@@ -477,6 +501,10 @@ impl DropFlags {
         }
         self.need_drop.remove(p)
     }
+
+    fn clear(&mut self) {
+        self.need_drop.clear();
+    }
 }
 
 #[derive(Debug)]
@@ -550,6 +578,26 @@ impl Evaluator<'_> {
             execution_limit: EXECUTION_LIMIT,
             memory_limit: 1000_000_000, // 2GB, 1GB for stack and 1GB for heap
             layout_cache: RefCell::new(HashMap::default()),
+            projected_ty_cache: RefCell::new(HashMap::default()),
+            not_special_fn_cache: RefCell::new(HashSet::default()),
+            mir_or_dyn_index_cache: RefCell::new(HashMap::default()),
+            unused_locals_store: RefCell::new(HashMap::default()),
+            cached_ptr_size: match db.target_data_layout(crate_id) {
+                Some(it) => it.pointer_size.bytes_usize(),
+                None => 8,
+            },
+            cached_fn_trait_func: db
+                .lang_item(crate_id, LangItem::Fn)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call])),
+            cached_fn_mut_trait_func: db
+                .lang_item(crate_id, LangItem::FnMut)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call_mut])),
+            cached_fn_once_trait_func: db
+                .lang_item(crate_id, LangItem::FnOnce)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call_once])),
         }
     }
@@ -570,10 +618,34 @@ impl Evaluator<'_> {
     }
 
     fn ptr_size(&self) -> usize {
-        match self.db.target_data_layout(self.crate_id) {
-            Some(it) => it.pointer_size.bytes_usize(),
-            None => 8,
-        }
+        self.cached_ptr_size
+    }
+
+    fn projected_ty(&self, ty: Ty, proj: PlaceElem) -> Ty {
+        let pair = (ty, proj);
+        if let Some(r) = self.projected_ty_cache.borrow().get(&pair) {
+            return r.clone();
+        }
+        let (ty, proj) = pair;
+        let r = proj.projected_ty(
+            ty.clone(),
+            self.db,
+            |c, subst, f| {
+                let (def, _) = self.db.lookup_intern_closure(c.into());
+                let infer = self.db.infer(def);
+                let (captures, _) = infer.closure_info(&c);
+                let parent_subst = ClosureSubst(subst).parent_subst();
+                captures
+                    .get(f)
+                    .expect("broken closure field")
+                    .ty
+                    .clone()
+                    .substitute(Interner, parent_subst)
+            },
+            self.crate_id,
+        );
+        self.projected_ty_cache.borrow_mut().insert((ty, proj), r.clone());
+        r
     }
 
     fn place_addr_and_ty_and_metadata<'a>(
@@ -586,23 +658,7 @@ impl Evaluator<'_> {
         let mut metadata: Option<IntervalOrOwned> = None; // locals are always sized
         for proj in &*p.projection {
             let prev_ty = ty.clone();
-            ty = proj.projected_ty(
-                ty,
-                self.db,
-                |c, subst, f| {
-                    let (def, _) = self.db.lookup_intern_closure(c.into());
-                    let infer = self.db.infer(def);
-                    let (captures, _) = infer.closure_info(&c);
-                    let parent_subst = ClosureSubst(subst).parent_subst();
-                    captures
-                        .get(f)
-                        .expect("broken closure field")
-                        .ty
-                        .clone()
-                        .substitute(Interner, parent_subst)
-                },
-                self.crate_id,
-            );
+            ty = self.projected_ty(ty, proj.clone());
             match proj {
                 ProjectionElem::Deref => {
                     metadata = if self.size_align_of(&ty, locals)?.is_none() {
@@ -756,18 +812,18 @@ impl Evaluator<'_> {
             return Err(MirEvalError::StackOverflow);
         }
         let mut current_block_idx = body.start_block;
-        let (mut locals, prev_stack_ptr) = self.create_locals_for_body(body.clone(), None)?;
+        let (mut locals, prev_stack_ptr) = self.create_locals_for_body(&body, None)?;
         self.fill_locals_for_body(&body, &mut locals, args)?;
         let prev_code_stack = mem::take(&mut self.code_stack);
         let span = (MirSpan::Unknown, body.owner);
-        self.code_stack.push(StackFrame { body, locals, destination: None, prev_stack_ptr, span });
+        self.code_stack.push(StackFrame { locals, destination: None, prev_stack_ptr, span });
         'stack: loop {
             let Some(mut my_stack_frame) = self.code_stack.pop() else {
                 not_supported!("missing stack frame");
             };
             let e = (|| {
                 let mut locals = &mut my_stack_frame.locals;
-                let body = &*my_stack_frame.body;
+                let body = locals.body.clone();
                 loop {
                     let current_block = &body.basic_blocks[current_block_idx];
                     if let Some(it) = self.execution_limit.checked_sub(1) {
@@ -836,7 +892,7 @@ impl Evaluator<'_> {
                             locals.drop_flags.add_place(destination.clone());
                             if let Some(stack_frame) = stack_frame {
                                 self.code_stack.push(my_stack_frame);
-                                current_block_idx = stack_frame.body.start_block;
+                                current_block_idx = stack_frame.locals.body.start_block;
                                 self.code_stack.push(stack_frame);
                                 return Ok(None);
                             } else {
@@ -877,18 +933,24 @@ impl Evaluator<'_> {
                     let my_code_stack = mem::replace(&mut self.code_stack, prev_code_stack);
                     let mut error_stack = vec![];
                     for frame in my_code_stack.into_iter().rev() {
-                        if let DefWithBodyId::FunctionId(f) = frame.body.owner {
+                        if let DefWithBodyId::FunctionId(f) = frame.locals.body.owner {
                            error_stack.push((Either::Left(f), frame.span.0, frame.span.1));
                        }
                    }
                    return Err(MirEvalError::InFunction(Box::new(e), error_stack));
                }
            };
+            let return_interval = my_stack_frame.locals.ptr[return_slot()];
+            self.unused_locals_store
+                .borrow_mut()
+                .entry(my_stack_frame.locals.body.owner)
+                .or_default()
+                .push(my_stack_frame.locals);
             match my_stack_frame.destination {
                 None => {
                     self.code_stack = prev_code_stack;
                     self.stack_depth_limit += 1;
-                    return Ok(my_stack_frame.locals.ptr[return_slot()].get(self)?.to_vec());
+                    return Ok(return_interval.get(self)?.to_vec());
                 }
                 Some(bb) => {
                     // We don't support const promotion, so we can't truncate the stack yet.
@@ -926,39 +988,45 @@ impl Evaluator<'_> {
     fn create_locals_for_body(
         &mut self,
-        body: Arc<MirBody>,
+        body: &Arc<MirBody>,
         destination: Option<Interval>,
     ) -> Result<(Locals, usize)> {
         let mut locals =
-            Locals { ptr: ArenaMap::new(), body: body.clone(), drop_flags: DropFlags::default() };
-        let (locals_ptr, stack_size) = {
+            match self.unused_locals_store.borrow_mut().entry(body.owner).or_default().pop() {
+                None => Locals {
+                    ptr: ArenaMap::new(),
+                    body: body.clone(),
+                    drop_flags: DropFlags::default(),
+                },
+                Some(mut l) => {
+                    l.drop_flags.clear();
+                    l.body = body.clone();
+                    l
+                }
+            };
+        let stack_size = {
             let mut stack_ptr = self.stack.len();
-            let addr = body
-                .locals
-                .iter()
-                .map(|(id, it)| {
-                    if id == return_slot() {
-                        if let Some(destination) = destination {
-                            return Ok((id, destination));
-                        }
-                    }
-                    let (size, align) = self.size_align_of_sized(
-                        &it.ty,
-                        &locals,
-                        "no unsized local in extending stack",
-                    )?;
-                    while stack_ptr % align != 0 {
-                        stack_ptr += 1;
-                    }
-                    let my_ptr = stack_ptr;
-                    stack_ptr += size;
-                    Ok((id, Interval { addr: Stack(my_ptr), size }))
-                })
-                .collect::<Result<ArenaMap<LocalId, _>>>()?;
-            let stack_size = stack_ptr - self.stack.len();
-            (addr, stack_size)
+            for (id, it) in body.locals.iter() {
+                if id == return_slot() {
+                    if let Some(destination) = destination {
+                        locals.ptr.insert(id, destination);
+                        continue;
+                    }
+                }
+                let (size, align) = self.size_align_of_sized(
+                    &it.ty,
+                    &locals,
+                    "no unsized local in extending stack",
+                )?;
+                while stack_ptr % align != 0 {
+                    stack_ptr += 1;
+                }
+                let my_ptr = stack_ptr;
+                stack_ptr += size;
+                locals.ptr.insert(id, Interval { addr: Stack(my_ptr), size });
+            }
+            stack_ptr - self.stack.len()
         };
-        locals.ptr = locals_ptr;
         let prev_stack_pointer = self.stack.len();
         if stack_size > self.memory_limit {
             return Err(MirEvalError::Panic(format!(
@@ -1693,6 +1761,11 @@ impl Evaluator<'_> {
     }
 
     fn size_align_of(&self, ty: &Ty, locals: &Locals) -> Result<Option<(usize, usize)>> {
+        if let Some(layout) = self.layout_cache.borrow().get(ty) {
+            return Ok(layout
+                .is_sized()
+                .then(|| (layout.size.bytes_usize(), layout.align.abi.bytes() as usize)));
+        }
         if let DefWithBodyId::VariantId(f) = locals.body.owner {
             if let Some((adt, _)) = ty.as_adt() {
                 if AdtId::from(f.parent) == adt {
@@ -1753,16 +1826,15 @@ impl Evaluator<'_> {
     }
 
     fn detect_fn_trait(&self, def: FunctionId) -> Option<FnTrait> {
-        use LangItem::*;
-        let ItemContainerId::TraitId(parent) = self.db.lookup_intern_function(def).container else {
-            return None;
-        };
-        let l = self.db.lang_attr(parent.into())?;
-        match l {
-            FnOnce => Some(FnTrait::FnOnce),
-            FnMut => Some(FnTrait::FnMut),
-            Fn => Some(FnTrait::Fn),
-            _ => None,
+        let def = Some(def);
+        if def == self.cached_fn_trait_func {
+            Some(FnTrait::Fn)
+        } else if def == self.cached_fn_mut_trait_func {
+            Some(FnTrait::FnMut)
+        } else if def == self.cached_fn_once_trait_func {
+            Some(FnTrait::FnOnce)
+        } else {
+            None
         }
     }
@@ -2105,6 +2177,40 @@ impl Evaluator<'_> {
         }
     }
 
+    fn get_mir_or_dyn_index(
+        &self,
+        def: FunctionId,
+        generic_args: Substitution,
+        locals: &Locals,
+        span: MirSpan,
+    ) -> Result<MirOrDynIndex> {
+        let pair = (def, generic_args);
+        if let Some(r) = self.mir_or_dyn_index_cache.borrow().get(&pair) {
+            return Ok(r.clone());
+        }
+        let (def, generic_args) = pair;
+        let r = if let Some(self_ty_idx) =
+            is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
+        {
+            MirOrDynIndex::Dyn(self_ty_idx)
+        } else {
+            let (imp, generic_args) =
+                self.db.lookup_impl_method(self.trait_env.clone(), def, generic_args.clone());
+            let mir_body = self
+                .db
+                .monomorphized_mir_body(imp.into(), generic_args, self.trait_env.clone())
+                .map_err(|e| {
+                    MirEvalError::InFunction(
+                        Box::new(MirEvalError::MirLowerError(imp, e)),
+                        vec![(Either::Left(imp), span, locals.body.owner)],
+                    )
+                })?;
+            MirOrDynIndex::Mir(mir_body)
+        };
+        self.mir_or_dyn_index_cache.borrow_mut().insert((def, generic_args), r.clone());
+        Ok(r)
+    }
+
     fn exec_fn_with_args(
         &mut self,
         def: FunctionId,
@@ -2126,93 +2232,76 @@ impl Evaluator<'_> {
             return Ok(None);
         }
         let arg_bytes = args.iter().map(|it| IntervalOrOwned::Borrowed(it.interval));
-        if let Some(self_ty_idx) =
-            is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
-        {
-            // In the layout of current possible receiver, which at the moment of writing this code is one of
-            // `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible recievers,
-            // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
-            // the type.
-            let first_arg = arg_bytes.clone().next().unwrap();
-            let first_arg = first_arg.get(self)?;
-            let ty =
-                self.vtable_map.ty_of_bytes(&first_arg[self.ptr_size()..self.ptr_size() * 2])?;
-            let mut args_for_target = args.to_vec();
-            args_for_target[0] = IntervalAndTy {
-                interval: args_for_target[0].interval.slice(0..self.ptr_size()),
-                ty: ty.clone(),
-            };
-            let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
-            let generics_for_target = Substitution::from_iter(
-                Interner,
-                generic_args.iter(Interner).enumerate().map(|(i, it)| {
-                    if i == self_ty_idx {
-                        &ty
-                    } else {
-                        it
-                    }
-                }),
-            );
-            return self.exec_fn_with_args(
-                def,
-                &args_for_target,
-                generics_for_target,
-                locals,
-                destination,
-                target_bb,
-                span,
-            );
-        }
-        let (imp, generic_args) =
-            self.db.lookup_impl_method(self.trait_env.clone(), def, generic_args);
-        self.exec_looked_up_function(
-            generic_args,
-            locals,
-            imp,
-            arg_bytes,
-            span,
-            destination,
-            target_bb,
-        )
+        match self.get_mir_or_dyn_index(def, generic_args.clone(), locals, span)? {
+            MirOrDynIndex::Dyn(self_ty_idx) => {
+                // In the layout of current possible receiver, which at the moment of writing this code is one of
+                // `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible recievers,
+                // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
+                // the type.
+                let first_arg = arg_bytes.clone().next().unwrap();
+                let first_arg = first_arg.get(self)?;
+                let ty = self
+                    .vtable_map
+                    .ty_of_bytes(&first_arg[self.ptr_size()..self.ptr_size() * 2])?;
+                let mut args_for_target = args.to_vec();
+                args_for_target[0] = IntervalAndTy {
+                    interval: args_for_target[0].interval.slice(0..self.ptr_size()),
+                    ty: ty.clone(),
+                };
+                let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
+                let generics_for_target = Substitution::from_iter(
+                    Interner,
+                    generic_args.iter(Interner).enumerate().map(|(i, it)| {
+                        if i == self_ty_idx {
+                            &ty
+                        } else {
+                            it
+                        }
+                    }),
+                );
+                return self.exec_fn_with_args(
+                    def,
+                    &args_for_target,
+                    generics_for_target,
+                    locals,
+                    destination,
+                    target_bb,
+                    span,
+                );
+            }
+            MirOrDynIndex::Mir(body) => self.exec_looked_up_function(
+                body,
+                locals,
+                def,
+                arg_bytes,
+                span,
+                destination,
+                target_bb,
+            ),
+        }
     }
 
     fn exec_looked_up_function(
         &mut self,
-        generic_args: Substitution,
+        mir_body: Arc<MirBody>,
         locals: &Locals,
-        imp: FunctionId,
+        def: FunctionId,
         arg_bytes: impl Iterator<Item = IntervalOrOwned>,
         span: MirSpan,
         destination: Interval,
         target_bb: Option<BasicBlockId>,
     ) -> Result<Option<StackFrame>> {
-        let def = imp.into();
-        let mir_body = self
-            .db
-            .monomorphized_mir_body(def, generic_args, self.trait_env.clone())
-            .map_err(|e| {
-                MirEvalError::InFunction(
-                    Box::new(MirEvalError::MirLowerError(imp, e)),
-                    vec![(Either::Left(imp), span, locals.body.owner)],
-                )
-            })?;
         Ok(if let Some(target_bb) = target_bb {
             let (mut locals, prev_stack_ptr) =
-                self.create_locals_for_body(mir_body.clone(), Some(destination))?;
+                self.create_locals_for_body(&mir_body, Some(destination))?;
             self.fill_locals_for_body(&mir_body, &mut locals, arg_bytes.into_iter())?;
             let span = (span, locals.body.owner);
-            Some(StackFrame {
-                body: mir_body,
-                locals,
-                destination: Some(target_bb),
-                prev_stack_ptr,
-                span,
-            })
+            Some(StackFrame { locals, destination: Some(target_bb), prev_stack_ptr, span })
         } else {
             let result = self.interpret_mir(mir_body, arg_bytes).map_err(|e| {
                 MirEvalError::InFunction(
                     Box::new(e),
-                    vec![(Either::Left(imp), span, locals.body.owner)],
+                    vec![(Either::Left(def), span, locals.body.owner)],
                 )
             })?;
             destination.write_from_bytes(self, &result)?;
@@ -2384,16 +2473,15 @@ impl Evaluator<'_> {
             // we can ignore drop in them.
             return Ok(());
         };
-        let (impl_drop_candidate, subst) = self.db.lookup_impl_method(
-            self.trait_env.clone(),
-            drop_fn,
-            Substitution::from1(Interner, ty.clone()),
-        );
-        if impl_drop_candidate != drop_fn {
+
+        let generic_args = Substitution::from1(Interner, ty.clone());
+        if let Ok(MirOrDynIndex::Mir(body)) =
+            self.get_mir_or_dyn_index(drop_fn, generic_args, locals, span)
+        {
             self.exec_looked_up_function(
-                subst,
+                body,
                 locals,
-                impl_drop_candidate,
+                drop_fn,
                 [IntervalOrOwned::Owned(addr.to_bytes())].into_iter(),
                 span,
                 Interval { addr: Address::Invalid(0), size: 0 },
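
A second pattern in this file is allocation reuse rather than memoization: when a frame finishes, `interpret_mir` now pushes its `Locals` into `unused_locals_store`, and `create_locals_for_body` pops and clears an old one instead of building a fresh `Locals` on every call. Below is a minimal pool sketch of that reuse idea, with a hypothetical `Locals` and string owners standing in for the real arena-backed types:

```rust
use std::collections::HashMap;

// Hypothetical stand-in: the real `Locals` bundles an arena map, the body, and drop flags.
#[derive(Default)]
struct Locals {
    slots: Vec<usize>,
}

/// Minimal object-pool sketch of `unused_locals_store`: instead of dropping a
/// frame's `Locals`, push it into a per-owner free list and pop it on the next call.
#[derive(Default)]
struct LocalsPool {
    free: HashMap<&'static str, Vec<Locals>>,
}

impl LocalsPool {
    fn take(&mut self, owner: &'static str) -> Locals {
        match self.free.entry(owner).or_default().pop() {
            Some(mut l) => {
                l.slots.clear(); // keep the allocation, reset the contents
                l
            }
            None => Locals::default(),
        }
    }

    fn give_back(&mut self, owner: &'static str, locals: Locals) {
        self.free.entry(owner).or_default().push(locals);
    }
}

fn main() {
    let mut pool = LocalsPool::default();
    let mut l = pool.take("fibonacci");
    l.slots.extend([1, 1, 2, 3, 5]);
    pool.give_back("fibonacci", l);
    // The next frame for the same body reuses the previous Vec's capacity.
    let l2 = pool.take("fibonacci");
    assert!(l2.slots.is_empty() && l2.slots.capacity() >= 5);
}
```

The win comes from keeping the underlying allocations alive across calls, which is what the `capacity()` assertion above is meant to illustrate.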

View file

@@ -36,6 +36,9 @@ impl Evaluator<'_> {
         destination: Interval,
         span: MirSpan,
     ) -> Result<bool> {
+        if self.not_special_fn_cache.borrow().contains(&def) {
+            return Ok(false);
+        }
         let function_data = self.db.function_data(def);
         let is_intrinsic = match &function_data.abi {
             Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
@@ -137,8 +140,11 @@ impl Evaluator<'_> {
                     self.exec_clone(def, args, self_ty.clone(), locals, destination, span)?;
                     return Ok(true);
                 }
+                // Return early to prevent caching clone as non special fn.
+                return Ok(false);
             }
         }
+        self.not_special_fn_cache.borrow_mut().insert(def);
        Ok(false)
    }
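
The shim change is a negative cache: once the special-function detection concludes that a function is not one of the specially handled shims, its id is inserted into `not_special_fn_cache` and later calls return immediately, while the `Clone` path returns early so it never gets recorded as non-special. A small sketch of that negative-result caching, with `u32` ids and a made-up `expensive_lookup` standing in for `FunctionId` and the real ABI/name checks:

```rust
use std::cell::RefCell;
use std::collections::HashSet;

/// Minimal sketch of the negative cache: remember ids that turned out not to be
/// "special", so repeated calls skip the expensive inspection entirely.
struct Shims {
    not_special_fn_cache: RefCell<HashSet<u32>>,
}

impl Shims {
    fn detect_special(&self, def: u32) -> bool {
        if self.not_special_fn_cache.borrow().contains(&def) {
            return false; // cached negative result
        }
        let is_special = expensive_lookup(def);
        if !is_special {
            // Only negative results are cached; positive ones stay cheap to re-detect.
            self.not_special_fn_cache.borrow_mut().insert(def);
        }
        is_special
    }
}

fn expensive_lookup(def: u32) -> bool {
    // Stand-in for checking ABI strings, lang items, known intrinsic names, ...
    def % 7 == 0
}

fn main() {
    let shims = Shims { not_special_fn_cache: RefCell::new(HashSet::new()) };
    assert!(shims.detect_special(14));
    assert!(!shims.detect_special(15));
    assert!(!shims.detect_special(15)); // second call answered from the cache
}
```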