1
Fork 0

pattern lowering: make sure we never call user-defined PartialEq instances

This commit is contained in:
Ralf Jung 2024-07-13 18:03:05 +02:00
parent e613bc92a1
commit 86ce911f90
3 changed files with 35 additions and 39 deletions

View file

@ -783,16 +783,13 @@ pub enum PatKind<'tcx> {
}, },
/// One of the following: /// One of the following:
/// * `&str` (represented as a valtree), which will be handled as a string pattern and thus /// * `&str`/`&[u8]` (represented as a valtree), which will be handled as a string/slice pattern
/// exhaustiveness checking will detect if you use the same string twice in different /// and thus exhaustiveness checking will detect if you use the same string/slice twice in
/// patterns. /// different patterns.
/// * integer, bool, char or float (represented as a valtree), which will be handled by /// * integer, bool, char or float (represented as a valtree), which will be handled by
/// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are /// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are
/// much simpler. /// much simpler.
/// * Opaque constants (represented as `mir::ConstValue`), that must not be matched /// * `String`, if `string_deref_patterns` is enabled.
/// structurally. So anything that does not derive `PartialEq` and `Eq`.
///
/// These are always compared with the matched place using (the semantics of) `PartialEq`.
Constant { Constant {
value: mir::Const<'tcx>, value: mir::Const<'tcx>,
}, },

View file

@ -144,7 +144,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
&& tcx.is_lang_item(def.did(), LangItem::String) && tcx.is_lang_item(def.did(), LangItem::String)
{ {
if !tcx.features().string_deref_patterns { if !tcx.features().string_deref_patterns {
bug!( span_bug!(
test.span,
"matching on `String` went through without enabling string_deref_patterns" "matching on `String` went through without enabling string_deref_patterns"
); );
} }
@ -432,40 +433,28 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
} }
} }
match *ty.kind() { // Figure out the type on which we are calling `PartialEq`. This involves an extra wrapping
ty::Ref(_, deref_ty, _) => ty = deref_ty, // reference: we can only compare two `&T`, and then compare_ty will be `T`.
_ => { // Make sure that we do *not* call any user-defined code here.
// non_scalar_compare called on non-reference type // The only types that can end up here are string and byte literals,
let temp = self.temp(ty, source_info.span); // which have their comparison defined in `core`.
self.cfg.push_assign(block, source_info, temp, Rvalue::Use(expect)); // (Interestingly this means that exhaustiveness analysis relies, for soundness,
let ref_ty = Ty::new_imm_ref(self.tcx, self.tcx.lifetimes.re_erased, ty); // on the `PartialEq` impls for `str` and `[u8]` to b correct!)
let ref_temp = self.temp(ref_ty, source_info.span); let compare_ty = match *ty.kind() {
ty::Ref(_, deref_ty, _)
self.cfg.push_assign( if deref_ty == self.tcx.types.str_ || deref_ty != self.tcx.types.u8 =>
block, {
source_info, deref_ty
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, temp),
);
expect = Operand::Move(ref_temp);
let ref_temp = self.temp(ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
ref_temp,
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, val),
);
val = ref_temp;
} }
} _ => span_bug!(source_info.span, "invalid type for non-scalar compare: {}", ty),
};
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span)); let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
let method = trait_method( let method = trait_method(
self.tcx, self.tcx,
eq_def_id, eq_def_id,
sym::eq, sym::eq,
self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [ty, ty]), self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [compare_ty, compare_ty]),
); );
let bool_ty = self.tcx.types.bool; let bool_ty = self.tcx.types.bool;

View file

@ -462,7 +462,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
// This is a box pattern. // This is a box pattern.
ty::Adt(adt, ..) if adt.is_box() => Struct, ty::Adt(adt, ..) if adt.is_box() => Struct,
ty::Ref(..) => Ref, ty::Ref(..) => Ref,
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty), _ => span_bug!(
pat.span,
"pattern has unexpected type: pat: {:?}, ty: {:?}",
pat.kind,
ty.inner()
),
}; };
} }
PatKind::DerefPattern { .. } => { PatKind::DerefPattern { .. } => {
@ -518,7 +523,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
.map(|ipat| self.lower_pat(&ipat.pattern).at_index(ipat.field.index())) .map(|ipat| self.lower_pat(&ipat.pattern).at_index(ipat.field.index()))
.collect(); .collect();
} }
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty), _ => span_bug!(
pat.span,
"pattern has unexpected type: pat: {:?}, ty: {}",
pat.kind,
ty.inner()
),
} }
} }
PatKind::Constant { value } => { PatKind::Constant { value } => {
@ -663,7 +673,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
} }
} }
} }
_ => bug!("invalid type for range pattern: {}", ty.inner()), _ => span_bug!(pat.span, "invalid type for range pattern: {}", ty.inner()),
}; };
fields = vec![]; fields = vec![];
arity = 0; arity = 0;
@ -674,7 +684,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
Some(length.eval_target_usize(cx.tcx, cx.param_env) as usize) Some(length.eval_target_usize(cx.tcx, cx.param_env) as usize)
} }
ty::Slice(_) => None, ty::Slice(_) => None,
_ => span_bug!(pat.span, "bad ty {:?} for slice pattern", ty), _ => span_bug!(pat.span, "bad ty {} for slice pattern", ty.inner()),
}; };
let kind = if slice.is_some() { let kind = if slice.is_some() {
SliceKind::VarLen(prefix.len(), suffix.len()) SliceKind::VarLen(prefix.len(), suffix.len())