1
Fork 0

Desugars contract into the internal AST extensions

Check ensures on early return due to Try / Yeet

Expand these two expressions to include a call to contract checking
This commit is contained in:
Felix S. Klock II 2024-12-03 02:52:29 +00:00 committed by Celina G. Val
parent 38eff16d0a
commit ae7eff0be5
12 changed files with 457 additions and 88 deletions

View file

@ -314,21 +314,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label)) hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
} }
ExprKind::Ret(e) => { ExprKind::Ret(e) => {
let mut e = e.as_ref().map(|x| self.lower_expr(x)); let expr = e.as_ref().map(|x| self.lower_expr(x));
if let Some(Some((span, fresh_ident))) = self self.checked_return(expr)
.contract
.as_ref()
.map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2);
let args = if let Some(e) = e {
std::slice::from_ref(e)
} else {
std::slice::from_ref(self.expr_unit(span))
};
e = Some(self.expr_call(span, checker_fn, args));
}
hir::ExprKind::Ret(e)
} }
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()), ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
ExprKind::Become(sub_expr) => { ExprKind::Become(sub_expr) => {
@ -395,6 +382,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
}) })
} }
/// Create an `ExprKind::Ret` that is preceded by a call to check contract ensures clause.
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
let checked_ret = if let Some(Some((span, fresh_ident))) =
self.contract.as_ref().map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(span));
Some(self.inject_ensures_check(expr, span, fresh_ident.0, fresh_ident.2))
} else {
opt_expr
};
hir::ExprKind::Ret(checked_ret)
}
/// Wraps an expression with a call to the ensures check before it gets returned.
pub(crate) fn inject_ensures_check(
&mut self,
expr: &'hir hir::Expr<'hir>,
span: Span,
check_ident: Ident,
check_hir_id: HirId,
) -> &'hir hir::Expr<'hir> {
let checker_fn = self.expr_ident(span, check_ident, check_hir_id);
let span = self.mark_span_with_reason(DesugaringKind::Contract, span, None);
self.expr_call(span, checker_fn, std::slice::from_ref(expr))
}
pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock { pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
self.with_new_scopes(c.value.span, |this| { self.with_new_scopes(c.value.span, |this| {
let def_id = this.local_def_id(c.id); let def_id = this.local_def_id(c.id);
@ -1983,7 +1996,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
), ),
)) ))
} else { } else {
self.arena.alloc(self.expr(try_span, hir::ExprKind::Ret(Some(from_residual_expr)))) let ret_expr = self.checked_return(Some(from_residual_expr));
self.arena.alloc(self.expr(try_span, ret_expr))
}; };
self.lower_attrs(ret_expr.hir_id, &attrs); self.lower_attrs(ret_expr.hir_id, &attrs);
@ -2032,7 +2046,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let target_id = Ok(catch_id); let target_id = Ok(catch_id);
hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr)) hir::ExprKind::Break(hir::Destination { label: None, target_id }, Some(from_yeet_expr))
} else { } else {
hir::ExprKind::Ret(Some(from_yeet_expr)) self.checked_return(Some(from_yeet_expr))
} }
} }

View file

@ -215,7 +215,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
if let Some(contract) = contract { if let Some(contract) = contract {
let requires = contract.requires.clone(); let requires = contract.requires.clone();
let ensures = contract.ensures.clone(); let ensures = contract.ensures.clone();
let ensures = if let Some(ens) = ensures { let ensures = ensures.map(|ens| {
// FIXME: this needs to be a fresh (or illegal) identifier to prevent // FIXME: this needs to be a fresh (or illegal) identifier to prevent
// accidental capture of a parameter or global variable. // accidental capture of a parameter or global variable.
let checker_ident: Ident = let checker_ident: Ident =
@ -226,13 +226,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::BindingMode::NONE, hir::BindingMode::NONE,
); );
Some(crate::FnContractLoweringEnsures { crate::FnContractLoweringEnsures {
expr: ens, expr: ens,
fresh_ident: (checker_ident, checker_pat, checker_hir_id), fresh_ident: (checker_ident, checker_pat, checker_hir_id),
}) }
} else { });
None
};
// Note: `with_new_scopes` will reinstall the outer // Note: `with_new_scopes` will reinstall the outer
// item's contract (if any) after its callback finishes. // item's contract (if any) after its callback finishes.
@ -1095,8 +1093,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
// { body } // { body }
// ==> // ==>
// { rustc_contract_requires(PRECOND); { body } } // { contract_requires(PRECOND); { body } }
let result: hir::Expr<'hir> = if let Some(contract) = opt_contract { let Some(contract) = opt_contract else { return (params, result) };
let result_ref = this.arena.alloc(result);
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| { let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
this.expr(contract.span, hir::ExprKind::Tup(&[])) this.expr(contract.span, hir::ExprKind::Tup(&[]))
}; };
@ -1131,37 +1130,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
this.arena.alloc(checker_binding_pat), this.arena.alloc(checker_binding_pat),
hir::LocalSource::Contract, hir::LocalSource::Contract,
), ),
{ this.inject_ensures_check(result_ref, ens.span, fresh_ident.0, fresh_ident.2),
let checker_fn =
this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2);
let span = this.mark_span_with_reason(
DesugaringKind::Contract,
ens.span,
None,
);
this.expr_call_mut(
span,
checker_fn,
std::slice::from_ref(this.arena.alloc(result)),
)
},
) )
} else { } else {
let u = lit_unit(this); let u = lit_unit(this);
(this.stmt_expr(contract.span, u), result) (this.stmt_expr(contract.span, u), &*result_ref)
}; };
let block = this.block_all( let block = this.block_all(
contract.span, contract.span,
arena_vec![this; precond, postcond_checker], arena_vec![this; precond, postcond_checker],
Some(this.arena.alloc(result)), Some(result),
); );
this.expr_block(block) (params, this.expr_block(block))
} else {
result
};
(params, result)
}) })
} }

View file

@ -0,0 +1,172 @@
#![allow(unused_imports, unused_variables)]
use rustc_ast::token;
use rustc_ast::tokenstream::{DelimSpacing, DelimSpan, Spacing, TokenStream, TokenTree};
use rustc_errors::ErrorGuaranteed;
use rustc_expand::base::{AttrProcMacro, ExtCtxt};
use rustc_span::Span;
use rustc_span::symbol::{Ident, Symbol, kw, sym};
pub(crate) struct ExpandRequires;
pub(crate) struct ExpandEnsures;
impl AttrProcMacro for ExpandRequires {
fn expand<'cx>(
&self,
ecx: &'cx mut ExtCtxt<'_>,
span: Span,
annotation: TokenStream,
annotated: TokenStream,
) -> Result<TokenStream, ErrorGuaranteed> {
expand_requires_tts(ecx, span, annotation, annotated)
}
}
impl AttrProcMacro for ExpandEnsures {
fn expand<'cx>(
&self,
ecx: &'cx mut ExtCtxt<'_>,
span: Span,
annotation: TokenStream,
annotated: TokenStream,
) -> Result<TokenStream, ErrorGuaranteed> {
expand_ensures_tts(ecx, span, annotation, annotated)
}
}
fn expand_injecting_circa_where_clause(
_ecx: &mut ExtCtxt<'_>,
attr_span: Span,
annotated: TokenStream,
inject: impl FnOnce(&mut Vec<TokenTree>) -> Result<(), ErrorGuaranteed>,
) -> Result<TokenStream, ErrorGuaranteed> {
let mut new_tts = Vec::with_capacity(annotated.len());
let mut cursor = annotated.into_trees();
// Find the `fn name<G,...>(x:X,...)` and inject the AST contract forms right after
// the formal parameters (and return type if any).
while let Some(tt) = cursor.next_ref() {
new_tts.push(tt.clone());
if let TokenTree::Token(tok, _) = tt
&& tok.is_ident_named(kw::Fn)
{
break;
}
}
// Found the `fn` keyword, now find the formal parameters.
//
// FIXME: can this fail if you have parentheticals in a generics list, like `fn foo<F: Fn(X) -> Y>` ?
while let Some(tt) = cursor.next_ref() {
new_tts.push(tt.clone());
if let TokenTree::Delimited(_, _, token::Delimiter::Parenthesis, _) = tt {
break;
}
if let TokenTree::Token(token::Token { kind: token::TokenKind::Semi, .. }, _) = tt {
panic!("contract attribute applied to fn without parameter list.");
}
}
// There *might* be a return type declaration (and figuring out where that ends would require
// parsing an arbitrary type expression, e.g. `-> Foo<args ...>`
//
// Instead of trying to figure that out, scan ahead and look for the first occurence of a
// `where`, a `{ ... }`, or a `;`.
//
// FIXME: this might still fall into a trap for something like `-> Ctor<T, const { 0 }>`. I
// *think* such cases must be under a Delimited (e.g. `[T; { N }]` or have the braced form
// prefixed by e.g. `const`, so we should still be able to filter them out without having to
// parse the type expression itself. But rather than try to fix things with hacks like that,
// time might be better spent extending the attribute expander to suport tt-annotation atop
// ast-annotated, which would be an elegant way to sidestep all of this.
let mut opt_next_tt = cursor.next_ref();
while let Some(next_tt) = opt_next_tt {
if let TokenTree::Token(tok, _) = next_tt
&& tok.is_ident_named(kw::Where)
{
break;
}
if let TokenTree::Delimited(_, _, token::Delimiter::Brace, _) = next_tt {
break;
}
if let TokenTree::Token(token::Token { kind: token::TokenKind::Semi, .. }, _) = next_tt {
break;
}
// for anything else, transcribe the tt and keep looking.
new_tts.push(next_tt.clone());
opt_next_tt = cursor.next_ref();
continue;
}
// At this point, we've transcribed everything from the `fn` through the formal parameter list
// and return type declaration, (if any), but `tt` itself has *not* been transcribed.
//
// Now inject the AST contract form.
//
// FIXME: this kind of manual token tree munging does not have significant precedent among
// rustc builtin macros, probably because most builtin macros use direct AST manipulation to
// accomplish similar goals. But since our attributes need to take arbitrary expressions, and
// our attribute infrastructure does not yet support mixing a token-tree annotation with an AST
// annotated, we end up doing token tree manipulation.
inject(&mut new_tts)?;
// Above we injected the internal AST requires/ensures contruct. Now copy over all the other
// token trees.
if let Some(tt) = opt_next_tt {
new_tts.push(tt.clone());
}
while let Some(tt) = cursor.next_ref() {
new_tts.push(tt.clone());
}
Ok(TokenStream::new(new_tts))
}
fn expand_requires_tts(
_ecx: &mut ExtCtxt<'_>,
attr_span: Span,
annotation: TokenStream,
annotated: TokenStream,
) -> Result<TokenStream, ErrorGuaranteed> {
expand_injecting_circa_where_clause(_ecx, attr_span, annotated, |new_tts| {
new_tts.push(TokenTree::Token(
token::Token::from_ast_ident(Ident::new(kw::RustcContractRequires, attr_span)),
Spacing::Joint,
));
new_tts.push(TokenTree::Token(
token::Token::new(token::TokenKind::OrOr, attr_span),
Spacing::Alone,
));
new_tts.push(TokenTree::Delimited(
DelimSpan::from_single(attr_span),
DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
token::Delimiter::Parenthesis,
annotation,
));
Ok(())
})
}
fn expand_ensures_tts(
_ecx: &mut ExtCtxt<'_>,
attr_span: Span,
annotation: TokenStream,
annotated: TokenStream,
) -> Result<TokenStream, ErrorGuaranteed> {
expand_injecting_circa_where_clause(_ecx, attr_span, annotated, |new_tts| {
new_tts.push(TokenTree::Token(
token::Token::from_ast_ident(Ident::new(kw::RustcContractEnsures, attr_span)),
Spacing::Joint,
));
new_tts.push(TokenTree::Delimited(
DelimSpan::from_single(attr_span),
DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
token::Delimiter::Parenthesis,
annotation,
));
Ok(())
})
}

View file

@ -55,6 +55,7 @@ mod trace_macros;
pub mod asm; pub mod asm;
pub mod cmdline_attrs; pub mod cmdline_attrs;
pub mod contracts;
pub mod proc_macro_harness; pub mod proc_macro_harness;
pub mod standard_library_imports; pub mod standard_library_imports;
pub mod test_harness; pub mod test_harness;
@ -137,4 +138,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
let client = proc_macro::bridge::client::Client::expand1(proc_macro::quote); let client = proc_macro::bridge::client::Client::expand1(proc_macro::quote);
register(sym::quote, SyntaxExtensionKind::Bang(Box::new(BangProcMacro { client }))); register(sym::quote, SyntaxExtensionKind::Bang(Box::new(BangProcMacro { client })));
let requires = SyntaxExtensionKind::Attr(Box::new(contracts::ExpandRequires));
register(sym::contracts_requires, requires);
let ensures = SyntaxExtensionKind::Attr(Box::new(contracts::ExpandEnsures));
register(sym::contracts_ensures, ensures);
} }

View file

@ -682,6 +682,8 @@ symbols! {
contract_check_ensures, contract_check_ensures,
contract_check_requires, contract_check_requires,
contract_checks, contract_checks,
contracts_ensures,
contracts_requires,
convert_identity, convert_identity,
copy, copy,
copy_closures, copy_closures,

View file

@ -1,5 +1,10 @@
//! Unstable module containing the unstable contracts lang items and attribute macros. //! Unstable module containing the unstable contracts lang items and attribute macros.
#[cfg(not(bootstrap))]
pub use crate::macros::builtin::contracts_ensures as ensures;
#[cfg(not(bootstrap))]
pub use crate::macros::builtin::contracts_requires as requires;
/// Emitted by rustc as a desugaring of `#[requires(PRED)] fn foo(x: X) { ... }` /// Emitted by rustc as a desugaring of `#[requires(PRED)] fn foo(x: X) { ... }`
/// into: `fn foo(x: X) { check_requires(|| PRED) ... }` /// into: `fn foo(x: X) { check_requires(|| PRED) ... }`
#[cfg(not(bootstrap))] #[cfg(not(bootstrap))]

View file

@ -1777,6 +1777,32 @@ pub(crate) mod builtin {
/* compiler built-in */ /* compiler built-in */
} }
/// Attribute macro applied to a function to give it a post-condition.
///
/// The attribute carries an argument token-tree which is
/// eventually parsed as a unary closure expression that is
/// invoked on a reference to the return value.
#[cfg(not(bootstrap))]
#[unstable(feature = "rustc_contracts", issue = "none")]
#[allow_internal_unstable(core_intrinsics)]
#[rustc_builtin_macro]
pub macro contracts_ensures($item:item) {
/* compiler built-in */
}
/// Attribute macro applied to a function to give it a precondition.
///
/// The attribute carries an argument token-tree which is
/// eventually parsed as an boolean expression with access to the
/// function's formal parameters
#[cfg(not(bootstrap))]
#[unstable(feature = "rustc_contracts", issue = "none")]
#[allow_internal_unstable(core_intrinsics)]
#[rustc_builtin_macro]
pub macro contracts_requires($item:item) {
/* compiler built-in */
}
/// Attribute macro applied to a function to register it as a handler for allocation failure. /// Attribute macro applied to a function to register it as a handler for allocation failure.
/// ///
/// See also [`std::alloc::handle_alloc_error`](../../../std/alloc/fn.handle_alloc_error.html). /// See also [`std::alloc::handle_alloc_error`](../../../std/alloc/fn.handle_alloc_error.html).

View file

@ -0,0 +1,44 @@
//@ revisions: unchk_pass unchk_fail_pre unchk_fail_post chk_pass chk_fail_pre chk_fail_post
//
//@ [unchk_pass] run-pass
//@ [unchk_fail_pre] run-pass
//@ [unchk_fail_post] run-pass
//@ [chk_pass] run-pass
//
//@ [chk_fail_pre] run-fail
//@ [chk_fail_post] run-fail
//
//@ [unchk_pass] compile-flags: -Zcontract-checks=no
//@ [unchk_fail_pre] compile-flags: -Zcontract-checks=no
//@ [unchk_fail_post] compile-flags: -Zcontract-checks=no
//
//@ [chk_pass] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_pre] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_post] compile-flags: -Zcontract-checks=yes
#![feature(rustc_contracts)]
#[core::contracts::requires(x.baz > 0)]
#[core::contracts::ensures(|ret| *ret > 100)]
fn nest(x: Baz) -> i32
{
loop {
return x.baz + 50;
}
}
struct Baz { baz: i32 }
const BAZ_PASS_PRE_POST: Baz = Baz { baz: 100 };
#[cfg(any(unchk_fail_post, chk_fail_post))]
const BAZ_FAIL_POST: Baz = Baz { baz: 10 };
#[cfg(any(unchk_fail_pre, chk_fail_pre))]
const BAZ_FAIL_PRE: Baz = Baz { baz: -10 };
fn main() {
assert_eq!(nest(BAZ_PASS_PRE_POST), 150);
#[cfg(any(unchk_fail_pre, chk_fail_pre))]
nest(BAZ_FAIL_PRE);
#[cfg(any(unchk_fail_post, chk_fail_post))]
nest(BAZ_FAIL_POST);
}

View file

@ -0,0 +1,42 @@
//@ revisions: unchk_pass unchk_fail_pre unchk_fail_post chk_pass chk_fail_pre chk_fail_post
//
//@ [unchk_pass] run-pass
//@ [unchk_fail_pre] run-pass
//@ [unchk_fail_post] run-pass
//@ [chk_pass] run-pass
//
//@ [chk_fail_pre] run-fail
//@ [chk_fail_post] run-fail
//
//@ [unchk_pass] compile-flags: -Zcontract-checks=no
//@ [unchk_fail_pre] compile-flags: -Zcontract-checks=no
//@ [unchk_fail_post] compile-flags: -Zcontract-checks=no
//
//@ [chk_pass] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_pre] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_post] compile-flags: -Zcontract-checks=yes
#![feature(rustc_contracts)]
#[core::contracts::requires(x.baz > 0)]
#[core::contracts::ensures(|ret| *ret > 100)]
fn tail(x: Baz) -> i32
{
x.baz + 50
}
struct Baz { baz: i32 }
const BAZ_PASS_PRE_POST: Baz = Baz { baz: 100 };
#[cfg(any(unchk_fail_post, chk_fail_post))]
const BAZ_FAIL_POST: Baz = Baz { baz: 10 };
#[cfg(any(unchk_fail_pre, chk_fail_pre))]
const BAZ_FAIL_PRE: Baz = Baz { baz: -10 };
fn main() {
assert_eq!(tail(BAZ_PASS_PRE_POST), 150);
#[cfg(any(unchk_fail_pre, chk_fail_pre))]
tail(BAZ_FAIL_PRE);
#[cfg(any(unchk_fail_post, chk_fail_post))]
tail(BAZ_FAIL_POST);
}

View file

@ -0,0 +1,48 @@
//@ revisions: unchk_pass chk_pass chk_fail_try chk_fail_ret chk_fail_yeet
//
//@ [unchk_pass] run-pass
//@ [chk_pass] run-pass
//@ [chk_fail_try] run-fail
//@ [chk_fail_ret] run-fail
//@ [chk_fail_yeet] run-fail
//
//@ [unchk_pass] compile-flags: -Zcontract-checks=no
//@ [chk_pass] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_try] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_ret] compile-flags: -Zcontract-checks=yes
//@ [chk_fail_yeet] compile-flags: -Zcontract-checks=yes
//! This test ensures that ensures clauses are checked for different return points of a function.
#![feature(rustc_contracts)]
#![feature(yeet_expr)]
/// This ensures will fail in different return points depending on the input.
#[core::contracts::ensures(|ret: &Option<u32>| ret.is_some())]
fn try_sum(x: u32, y: u32, z: u32) -> Option<u32> {
// Use Yeet to return early.
if x == u32::MAX && (y > 0 || z > 0) { do yeet }
// Use `?` to early return.
let partial = x.checked_add(y)?;
// Explicitly use `return` clause.
if u32::MAX - partial < z {
return None;
}
Some(partial + z)
}
fn main() {
// This should always succeed
assert_eq!(try_sum(0, 1, 2), Some(3));
#[cfg(any(unchk_pass, chk_fail_yeet))]
assert_eq!(try_sum(u32::MAX, 1, 1), None);
#[cfg(any(unchk_pass, chk_fail_try))]
assert_eq!(try_sum(u32::MAX - 10, 12, 0), None);
#[cfg(any(unchk_pass, chk_fail_ret))]
assert_eq!(try_sum(u32::MAX - 10, 2, 100), None);
}

View file

@ -0,0 +1,14 @@
//@ run-pass
//@ compile-flags: -Zcontract-checks=yes
#![feature(rustc_contracts)]
#[core::contracts::ensures(|ret| *ret > 0)]
fn outer() -> i32 {
let inner_closure = || -> i32 { 0 };
inner_closure();
10
}
fn main() {
outer();
}

View file

@ -0,0 +1,16 @@
//@ run-pass
//@ compile-flags: -Zcontract-checks=yes
#![feature(rustc_contracts)]
struct Outer { outer: std::cell::Cell<i32> }
#[core::contracts::requires(x.outer.get() > 0)]
fn outer(x: Outer) {
let inner_closure = || { };
x.outer.set(0);
inner_closure();
}
fn main() {
outer(Outer { outer: 1.into() });
}