Rollup merge of #117095 - klinvill:smir-fn-arg-count, r=oli-obk
Add way to differentiate argument locals from other locals in Stable MIR This PR resolves rust-lang/project-stable-mir#47 which request a way to differentiate argument locals in a SMIR `Body` from other locals. Specifically, this PR exposes the `arg_count` field from the MIR `Body`. However, I'm opening this as a draft PR because I think there are a few outstanding questions on how this information should be exposed and described. Namely: - Is exposing `arg_count` the best way to surface this information to SMIR users? Would it be better to leave `arg_count` as a private field and add public methods (e.g. `fn arguments(&self) -> Iter<'_, LocalDecls>`) that may use the underlying `arg_count` info from the MIR body, but expose this information to users in a more convenient form? Or is it best to stick close to the current MIR convention? - If the answer to the above point is to stick with the current MIR convention (`arg_count`), is it reasonable to also commit to sticking to the current MIR convention that the first local is always the return local, while the next `arg_count` locals are always the (in-order) argument locals? - Should `Body` in SMIR only represent function bodies (as implied by the comment I added)? That seems to be the current case in MIR, but should this restriction always be the case for SMIR? r? `@celinval` r? `@oli-obk`
This commit is contained in:
commit
b66c6e719f
5 changed files with 102 additions and 22 deletions
|
@ -287,9 +287,8 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> {
|
||||||
type T = stable_mir::mir::Body;
|
type T = stable_mir::mir::Body;
|
||||||
|
|
||||||
fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
|
fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
|
||||||
stable_mir::mir::Body {
|
stable_mir::mir::Body::new(
|
||||||
blocks: self
|
self.basic_blocks
|
||||||
.basic_blocks
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|block| stable_mir::mir::BasicBlock {
|
.map(|block| stable_mir::mir::BasicBlock {
|
||||||
terminator: block.terminator().stable(tables),
|
terminator: block.terminator().stable(tables),
|
||||||
|
@ -300,15 +299,15 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> {
|
||||||
.collect(),
|
.collect(),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
locals: self
|
self.local_decls
|
||||||
.local_decls
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|decl| stable_mir::mir::LocalDecl {
|
.map(|decl| stable_mir::mir::LocalDecl {
|
||||||
ty: decl.ty.stable(tables),
|
ty: decl.ty.stable(tables),
|
||||||
span: decl.source_info.span.stable(tables),
|
span: decl.source_info.span.stable(tables),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
self.arg_count,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,60 @@ use crate::ty::{AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability
|
||||||
use crate::Opaque;
|
use crate::Opaque;
|
||||||
use crate::{ty::Ty, Span};
|
use crate::{ty::Ty, Span};
|
||||||
|
|
||||||
|
/// The SMIR representation of a single function.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Body {
|
pub struct Body {
|
||||||
pub blocks: Vec<BasicBlock>,
|
pub blocks: Vec<BasicBlock>,
|
||||||
pub locals: LocalDecls,
|
|
||||||
|
// Declarations of locals within the function.
|
||||||
|
//
|
||||||
|
// The first local is the return value pointer, followed by `arg_count`
|
||||||
|
// locals for the function arguments, followed by any user-declared
|
||||||
|
// variables and temporaries.
|
||||||
|
locals: LocalDecls,
|
||||||
|
|
||||||
|
// The number of arguments this function takes.
|
||||||
|
arg_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Body {
|
||||||
|
/// Constructs a `Body`.
|
||||||
|
///
|
||||||
|
/// A constructor is required to build a `Body` from outside the crate
|
||||||
|
/// because the `arg_count` and `locals` fields are private.
|
||||||
|
pub fn new(blocks: Vec<BasicBlock>, locals: LocalDecls, arg_count: usize) -> Self {
|
||||||
|
// If locals doesn't contain enough entries, it can lead to panics in
|
||||||
|
// `ret_local`, `arg_locals`, and `inner_locals`.
|
||||||
|
assert!(
|
||||||
|
locals.len() > arg_count,
|
||||||
|
"A Body must contain at least a local for the return value and each of the function's arguments"
|
||||||
|
);
|
||||||
|
Self { blocks, locals, arg_count }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return local that holds this function's return value.
|
||||||
|
pub fn ret_local(&self) -> &LocalDecl {
|
||||||
|
&self.locals[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Locals in `self` that correspond to this function's arguments.
|
||||||
|
pub fn arg_locals(&self) -> &[LocalDecl] {
|
||||||
|
&self.locals[1..][..self.arg_count]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inner locals for this function. These are the locals that are
|
||||||
|
/// neither the return local nor the argument locals.
|
||||||
|
pub fn inner_locals(&self) -> &[LocalDecl] {
|
||||||
|
&self.locals[self.arg_count + 1..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function to get all the locals in this function.
|
||||||
|
///
|
||||||
|
/// Locals are typically accessed via the more specific methods `ret_local`,
|
||||||
|
/// `arg_locals`, and `inner_locals`.
|
||||||
|
pub fn locals(&self) -> &[LocalDecl] {
|
||||||
|
&self.locals
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalDecls = Vec<LocalDecl>;
|
type LocalDecls = Vec<LocalDecl>;
|
||||||
|
@ -467,7 +517,7 @@ pub enum NullOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Operand {
|
impl Operand {
|
||||||
pub fn ty(&self, locals: &LocalDecls) -> Ty {
|
pub fn ty(&self, locals: &[LocalDecl]) -> Ty {
|
||||||
match self {
|
match self {
|
||||||
Operand::Copy(place) | Operand::Move(place) => place.ty(locals),
|
Operand::Copy(place) | Operand::Move(place) => place.ty(locals),
|
||||||
Operand::Constant(c) => c.ty(),
|
Operand::Constant(c) => c.ty(),
|
||||||
|
@ -482,7 +532,7 @@ impl Constant {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Place {
|
impl Place {
|
||||||
pub fn ty(&self, locals: &LocalDecls) -> Ty {
|
pub fn ty(&self, locals: &[LocalDecl]) -> Ty {
|
||||||
let _start_ty = locals[self.local].ty;
|
let _start_ty = locals[self.local].ty;
|
||||||
todo!("Implement projection")
|
todo!("Implement projection")
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ fn test_body(body: mir::Body) {
|
||||||
for term in body.blocks.iter().map(|bb| &bb.terminator) {
|
for term in body.blocks.iter().map(|bb| &bb.terminator) {
|
||||||
match &term.kind {
|
match &term.kind {
|
||||||
Call { func, .. } => {
|
Call { func, .. } => {
|
||||||
let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() };
|
let TyKind::RigidTy(ty) = func.ty(body.locals()).kind() else { unreachable!() };
|
||||||
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
|
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
|
||||||
let result = Instance::resolve(def, &args);
|
let result = Instance::resolve(def, &args);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
|
@ -47,7 +47,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
|
|
||||||
let bar = get_item(&items, (DefKind::Fn, "bar")).unwrap();
|
let bar = get_item(&items, (DefKind::Fn, "bar")).unwrap();
|
||||||
let body = bar.body();
|
let body = bar.body();
|
||||||
assert_eq!(body.locals.len(), 2);
|
assert_eq!(body.locals().len(), 2);
|
||||||
assert_eq!(body.blocks.len(), 1);
|
assert_eq!(body.blocks.len(), 1);
|
||||||
let block = &body.blocks[0];
|
let block = &body.blocks[0];
|
||||||
assert_eq!(block.statements.len(), 1);
|
assert_eq!(block.statements.len(), 1);
|
||||||
|
@ -62,7 +62,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
|
|
||||||
let foo_bar = get_item(&items, (DefKind::Fn, "foo_bar")).unwrap();
|
let foo_bar = get_item(&items, (DefKind::Fn, "foo_bar")).unwrap();
|
||||||
let body = foo_bar.body();
|
let body = foo_bar.body();
|
||||||
assert_eq!(body.locals.len(), 5);
|
assert_eq!(body.locals().len(), 5);
|
||||||
assert_eq!(body.blocks.len(), 4);
|
assert_eq!(body.blocks.len(), 4);
|
||||||
let block = &body.blocks[0];
|
let block = &body.blocks[0];
|
||||||
match &block.terminator.kind {
|
match &block.terminator.kind {
|
||||||
|
@ -72,29 +72,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
|
|
||||||
let types = get_item(&items, (DefKind::Fn, "types")).unwrap();
|
let types = get_item(&items, (DefKind::Fn, "types")).unwrap();
|
||||||
let body = types.body();
|
let body = types.body();
|
||||||
assert_eq!(body.locals.len(), 6);
|
assert_eq!(body.locals().len(), 6);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[0].ty.kind(),
|
body.locals()[0].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[1].ty.kind(),
|
body.locals()[1].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[2].ty.kind(),
|
body.locals()[2].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char)
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char)
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[3].ty.kind(),
|
body.locals()[3].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32))
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32))
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[4].ty.kind(),
|
body.locals()[4].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64))
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64))
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
body.locals[5].ty.kind(),
|
body.locals()[5].ty.kind(),
|
||||||
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Float(
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Float(
|
||||||
stable_mir::ty::FloatTy::F64
|
stable_mir::ty::FloatTy::F64
|
||||||
))
|
))
|
||||||
|
@ -123,10 +123,10 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
for block in instance.body().blocks {
|
for block in instance.body().blocks {
|
||||||
match &block.terminator.kind {
|
match &block.terminator.kind {
|
||||||
stable_mir::mir::TerminatorKind::Call { func, .. } => {
|
stable_mir::mir::TerminatorKind::Call { func, .. } => {
|
||||||
let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() };
|
let TyKind::RigidTy(ty) = func.ty(&body.locals()).kind() else { unreachable!() };
|
||||||
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
|
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
|
||||||
let next_func = Instance::resolve(def, &args).unwrap();
|
let next_func = Instance::resolve(def, &args).unwrap();
|
||||||
match next_func.body().locals[1].ty.kind() {
|
match next_func.body().locals()[1].ty.kind() {
|
||||||
TyKind::RigidTy(RigidTy::Uint(_)) | TyKind::RigidTy(RigidTy::Tuple(_)) => {}
|
TyKind::RigidTy(RigidTy::Uint(_)) | TyKind::RigidTy(RigidTy::Tuple(_)) => {}
|
||||||
other => panic!("{other:?}"),
|
other => panic!("{other:?}"),
|
||||||
}
|
}
|
||||||
|
@ -140,6 +140,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
// Ensure we don't panic trying to get the body of a constant.
|
// Ensure we don't panic trying to get the body of a constant.
|
||||||
foo_const.body();
|
foo_const.body();
|
||||||
|
|
||||||
|
let locals_fn = get_item(&items, (DefKind::Fn, "locals")).unwrap();
|
||||||
|
let body = locals_fn.body();
|
||||||
|
assert_eq!(body.locals().len(), 4);
|
||||||
|
assert_matches!(
|
||||||
|
body.ret_local().ty.kind(),
|
||||||
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char)
|
||||||
|
);
|
||||||
|
assert_eq!(body.arg_locals().len(), 2);
|
||||||
|
assert_matches!(
|
||||||
|
body.arg_locals()[0].ty.kind(),
|
||||||
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32))
|
||||||
|
);
|
||||||
|
assert_matches!(
|
||||||
|
body.arg_locals()[1].ty.kind(),
|
||||||
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64))
|
||||||
|
);
|
||||||
|
assert_eq!(body.inner_locals().len(), 1);
|
||||||
|
// If conditions have an extra inner local to hold their results
|
||||||
|
assert_matches!(
|
||||||
|
body.inner_locals()[0].ty.kind(),
|
||||||
|
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
|
||||||
|
);
|
||||||
|
|
||||||
ControlFlow::Continue(())
|
ControlFlow::Continue(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,6 +234,14 @@ fn generate_input(path: &str) -> std::io::Result<()> {
|
||||||
|
|
||||||
pub fn assert(x: i32) -> i32 {{
|
pub fn assert(x: i32) -> i32 {{
|
||||||
x + 1
|
x + 1
|
||||||
|
}}
|
||||||
|
|
||||||
|
pub fn locals(a: i32, _: u64) -> char {{
|
||||||
|
if a > 5 {{
|
||||||
|
'a'
|
||||||
|
}} else {{
|
||||||
|
'b'
|
||||||
|
}}
|
||||||
}}"#
|
}}"#
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -29,7 +29,7 @@ const CRATE_NAME: &str = "input";
|
||||||
fn test_translation(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
fn test_translation(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
|
||||||
let main_fn = stable_mir::entry_fn().unwrap();
|
let main_fn = stable_mir::entry_fn().unwrap();
|
||||||
let body = main_fn.body();
|
let body = main_fn.body();
|
||||||
let orig_ty = body.locals[0].ty;
|
let orig_ty = body.locals()[0].ty;
|
||||||
let rustc_ty = rustc_internal::internal(&orig_ty);
|
let rustc_ty = rustc_internal::internal(&orig_ty);
|
||||||
assert!(rustc_ty.is_unit());
|
assert!(rustc_ty.is_unit());
|
||||||
ControlFlow::Continue(())
|
ControlFlow::Continue(())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue