1
Fork 0

Address review comments

Clean up code and add comments.
Use InlineConstant to wrap range patterns.
This commit is contained in:
Matthew Jasper 2023-10-09 16:31:44 +00:00
parent 5cc83fd4a5
commit 8aea0e9590
7 changed files with 57 additions and 63 deletions

View file

@ -765,8 +765,19 @@ pub enum PatKind<'tcx> {
value: mir::Const<'tcx>, value: mir::Const<'tcx>,
}, },
/// Inline constant found while lowering a pattern.
InlineConstant { InlineConstant {
value: mir::UnevaluatedConst<'tcx>, /// [LocalDefId] of the constant, we need this so that we have a
/// reference that can be used by unsafety checking to visit nested
/// unevaluated constants.
def: LocalDefId,
/// If the inline constant is used in a range pattern, this subpattern
/// represents the range (if both ends are inline constants, there will
/// be multiple InlineConstant wrappers).
///
/// Otherwise, the actual pattern that the constant lowered to. As with
/// other constants, inline constants are matched structurally where
/// possible.
subpattern: Box<Pat<'tcx>>, subpattern: Box<Pat<'tcx>>,
}, },
@ -930,7 +941,7 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
write!(f, "{subpattern}") write!(f, "{subpattern}")
} }
PatKind::Constant { value } => write!(f, "{value}"), PatKind::Constant { value } => write!(f, "{value}"),
PatKind::InlineConstant { value: _, ref subpattern } => { PatKind::InlineConstant { def: _, ref subpattern } => {
write!(f, "{} (from inline const)", subpattern) write!(f, "{} (from inline const)", subpattern)
} }
PatKind::Range(box PatRange { lo, hi, end }) => { PatKind::Range(box PatRange { lo, hi, end }) => {

View file

@ -233,7 +233,7 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
} }
} }
Constant { value: _ } => {} Constant { value: _ } => {}
InlineConstant { value: _, subpattern } => visitor.visit_pat(subpattern), InlineConstant { def: _, subpattern } => visitor.visit_pat(subpattern),
Range(_) => {} Range(_) => {}
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => { Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
for subpattern in prefix.iter() { for subpattern in prefix.iter() {

View file

@ -204,7 +204,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Err(match_pair) Err(match_pair)
} }
PatKind::InlineConstant { subpattern: ref pattern, value: _ } => { PatKind::InlineConstant { subpattern: ref pattern, def: _ } => {
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self)); candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));
Ok(()) Ok(())
@ -236,20 +236,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// pattern/_match.rs for another pertinent example of this pattern). // pattern/_match.rs for another pertinent example of this pattern).
// //
// Also, for performance, it's important to only do the second // Also, for performance, it's important to only do the second
// `try_eval_scalar_int` if necessary. // `try_to_bits` if necessary.
let lo = lo let lo = lo.try_to_bits(sz).unwrap() ^ bias;
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
if lo <= min { if lo <= min {
let hi = hi let hi = hi.try_to_bits(sz).unwrap() ^ bias;
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
if hi > max || hi == max && end == RangeEnd::Included { if hi > max || hi == max && end == RangeEnd::Included {
// Irrefutable pattern match. // Irrefutable pattern match.
return Ok(()); return Ok(());

View file

@ -67,7 +67,6 @@ fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty), thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
}; };
tcx.ensure().check_match(def);
// this must run before MIR dump, because // this must run before MIR dump, because
// "not all control paths return a value" is reported here. // "not all control paths return a value" is reported here.
// //

View file

@ -3,7 +3,7 @@ use crate::errors::*;
use rustc_middle::thir::visit::{self, Visitor}; use rustc_middle::thir::visit::{self, Visitor};
use rustc_hir as hir; use rustc_hir as hir;
use rustc_middle::mir::{BorrowKind, Const}; use rustc_middle::mir::BorrowKind;
use rustc_middle::thir::*; use rustc_middle::thir::*;
use rustc_middle::ty::print::with_no_trimmed_paths; use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt}; use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
@ -124,7 +124,8 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body. /// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
fn visit_inner_body(&mut self, def: LocalDefId) { fn visit_inner_body(&mut self, def: LocalDefId) {
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) { if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
let _ = self.tcx.ensure_with_value().mir_built(def); // Runs all other queries that depend on THIR.
self.tcx.ensure_with_value().mir_built(def);
let inner_thir = &inner_thir.steal(); let inner_thir = &inner_thir.steal();
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def); let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self }; let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
@ -279,23 +280,8 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat); visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt; self.inside_adt = old_inside_adt;
} }
PatKind::Range(range) => { PatKind::InlineConstant { def, .. } => {
if let Const::Unevaluated(c, _) = range.lo { self.visit_inner_body(*def);
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
if let Const::Unevaluated(c, _) = range.hi {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
}
PatKind::InlineConstant { value, .. } => {
let def_id = value.def.expect_local();
self.visit_inner_body(def_id);
} }
_ => { _ => {
visit::walk_pat(self, pat); visit::walk_pat(self, pat);
@ -808,7 +794,8 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
} }
let Ok((thir, expr)) = tcx.thir_body(def) else { return }; let Ok((thir, expr)) = tcx.thir_body(def) else { return };
let _ = tcx.ensure_with_value().mir_built(def); // Runs all other queries that depend on THIR.
tcx.ensure_with_value().mir_built(def);
let thir = &thir.steal(); let thir = &thir.steal();
// If `thir` is empty, a type error occurred, skip this body. // If `thir` is empty, a type error occurred, skip this body.
if thir.exprs.is_empty() { if thir.exprs.is_empty() {

View file

@ -27,6 +27,7 @@ use rustc_middle::ty::{
self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt, self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt,
TypeVisitableExt, UserType, TypeVisitableExt, UserType,
}; };
use rustc_span::def_id::LocalDefId;
use rustc_span::{ErrorGuaranteed, Span, Symbol}; use rustc_span::{ErrorGuaranteed, Span, Symbol};
use rustc_target::abi::{FieldIdx, Integer}; use rustc_target::abi::{FieldIdx, Integer};
@ -88,19 +89,21 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
fn lower_pattern_range_endpoint( fn lower_pattern_range_endpoint(
&mut self, &mut self,
expr: Option<&'tcx hir::Expr<'tcx>>, expr: Option<&'tcx hir::Expr<'tcx>>,
) -> Result<(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>), ErrorGuaranteed> { ) -> Result<
(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>, Option<LocalDefId>),
ErrorGuaranteed,
> {
match expr { match expr {
None => Ok((None, None)), None => Ok((None, None, None)),
Some(expr) => { Some(expr) => {
let (kind, ascr) = match self.lower_lit(expr) { let (kind, ascr, inline_const) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, value } => ( PatKind::InlineConstant { subpattern, def } => {
PatKind::Constant { value: Const::Unevaluated(value, subpattern.ty) }, (subpattern.kind, None, Some(def))
None,
),
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription))
} }
kind => (kind, None), PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription), None)
}
kind => (kind, None, None),
}; };
let value = if let PatKind::Constant { value } = kind { let value = if let PatKind::Constant { value } = kind {
value value
@ -110,7 +113,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
); );
return Err(self.tcx.sess.delay_span_bug(expr.span, msg)); return Err(self.tcx.sess.delay_span_bug(expr.span, msg));
}; };
Ok((Some(value), ascr)) Ok((Some(value), ascr, inline_const))
} }
} }
} }
@ -181,8 +184,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
return Err(self.tcx.sess.delay_span_bug(span, msg)); return Err(self.tcx.sess.delay_span_bug(span, msg));
} }
let (lo, lo_ascr) = self.lower_pattern_range_endpoint(lo_expr)?; let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr) = self.lower_pattern_range_endpoint(hi_expr)?; let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?;
let lo = lo.unwrap_or_else(|| { let lo = lo.unwrap_or_else(|| {
// Unwrap is ok because the type is known to be numeric. // Unwrap is ok because the type is known to be numeric.
@ -241,6 +244,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
}; };
} }
} }
for inline_const in [lo_inline, hi_inline] {
if let Some(def) = inline_const {
kind =
PatKind::InlineConstant { def, subpattern: Box::new(Pat { span, ty, kind }) };
}
}
Ok(kind) Ok(kind)
} }
@ -603,11 +612,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// const eval path below. // const eval path below.
// FIXME: investigate the performance impact of removing this. // FIXME: investigate the performance impact of removing this.
let lit_input = match expr.kind { let lit_input = match expr.kind {
hir::ExprKind::Lit(ref lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }), hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, ref expr) => match expr.kind { hir::ExprKind::Unary(hir::UnOp::Neg, expr) => match expr.kind {
hir::ExprKind::Lit(ref lit) => { hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: true }),
Some(LitToConstInput { lit: &lit.node, ty, neg: true })
}
_ => None, _ => None,
}, },
_ => None, _ => None,
@ -643,7 +650,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
span, span,
None, None,
); );
PatKind::InlineConstant { subpattern, value: uneval } PatKind::InlineConstant { subpattern, def: def_id }
} else { } else {
// If that fails, convert it to an opaque constant pattern. // If that fails, convert it to an opaque constant pattern.
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) { match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
@ -826,8 +833,8 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
PatKind::Deref { subpattern: subpattern.fold_with(folder) } PatKind::Deref { subpattern: subpattern.fold_with(folder) }
} }
PatKind::Constant { value } => PatKind::Constant { value }, PatKind::Constant { value } => PatKind::Constant { value },
PatKind::InlineConstant { value, subpattern: ref pattern } => { PatKind::InlineConstant { def, subpattern: ref pattern } => {
PatKind::InlineConstant { value, subpattern: pattern.fold_with(folder) } PatKind::InlineConstant { def, subpattern: pattern.fold_with(folder) }
} }
PatKind::Range(ref range) => PatKind::Range(range.clone()), PatKind::Range(ref range) => PatKind::Range(range.clone()),
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice { PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {

View file

@ -692,7 +692,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
} }
PatKind::Deref { subpattern } => { PatKind::Deref { subpattern } => {
print_indented!(self, "Deref { ", depth_lvl + 1); print_indented!(self, "Deref { ", depth_lvl + 1);
print_indented!(self, "subpattern: ", depth_lvl + 2); print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2); self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1); print_indented!(self, "}", depth_lvl + 1);
} }
@ -701,10 +701,10 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2); print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1); print_indented!(self, "}", depth_lvl + 1);
} }
PatKind::InlineConstant { value, subpattern } => { PatKind::InlineConstant { def, subpattern } => {
print_indented!(self, "InlineConstant {", depth_lvl + 1); print_indented!(self, "InlineConstant {", depth_lvl + 1);
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2); print_indented!(self, format!("def: {:?}", def), depth_lvl + 2);
print_indented!(self, "subpattern: ", depth_lvl + 2); print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2); self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1); print_indented!(self, "}", depth_lvl + 1);
} }