Rollup merge of #126553 - Nadrieril:expand-or-pat-into-above, r=matthewjasper

match lowering: expand or-candidates mixed with candidates above

This PR tweaks match lowering of or-patterns. Consider this:
```rust
match (x, y) {
    (1, true) => 1,
    (2, false) => 2,
    (1 | 2, true | false) => 3,
    (3 | 4, true | false) => 4,
    _ => 5,
}
```
One might hope that this can be compiled to a single `SwitchInt` on `x` followed by some boolean checks. Before this PR, we compiled this to 3 `SwitchInt`s on `x`, because an arm that contains more than one or-pattern was compiled on its own. This PR groups branch `3` with the two branches above, getting us down to 2 `SwitchInt`s on `x`.

We can't in general expand or-patterns freely, because this interacts poorly with another optimization we do: or-pattern simplification. When an or-pattern doesn't involve bindings, we branch the success paths of all its alternatives to the same block. The drawback is that in a case like:
```rust
match (1, true) {
    (1 | 2, false) => unreachable!(),
    (2, _) => unreachable!(),
    _ => {}
}
```
if we used a single `SwitchInt`, by the time we test `false` we don't know whether we came from the `1` case or the `2` case, so we don't know where to go if `false` doesn't match.

Hence the limitation: we can process the alternatives of an or-pattern alongside the candidates that precede the or-pattern's candidate, but not alongside the candidates that follow it (unless the or-pattern is the only remaining match pair of its candidate, in which case we can process it alongside any other candidates).

This PR allows or-pattern alternatives to be processed alongside the candidates that precede the or-pattern's candidate. One benefit is that we now process or-patterns in a single place in `mod.rs`.

r? ``@matthewjasper``
This commit is contained in:
León Orell Valerian Liehr 2024-06-19 09:52:00 +02:00 committed by GitHub
commit e111e99253
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 366 additions and 113 deletions

View file

@ -1074,12 +1074,9 @@ struct Candidate<'pat, 'tcx> {
// because that would break binding consistency.
subcandidates: Vec<Candidate<'pat, 'tcx>>,
/// ...and the guard must be evaluated if there is one.
/// ...and if there is a guard it must be evaluated; if it's `false` then branch to `otherwise_block`.
has_guard: bool,
/// If the guard is `false` then branch to `otherwise_block`.
otherwise_block: Option<BasicBlock>,
/// If the candidate matches, bindings and ascriptions must be established.
extra_data: PatternExtraData<'tcx>,
@ -1090,6 +1087,9 @@ struct Candidate<'pat, 'tcx> {
/// The block before the `bindings` have been established.
pre_binding_block: Option<BasicBlock>,
/// The block to branch to if the guard or a nested candidate fails to match.
otherwise_block: Option<BasicBlock>,
/// The earliest block that has only candidates >= this one as descendants. Used for false
/// edges, see the doc for [`Builder::match_expr`].
false_edge_start_block: Option<BasicBlock>,
@ -1364,56 +1364,105 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
otherwise_block: BasicBlock,
candidates: &mut [&mut Candidate<'pat, 'tcx>],
) {
let mut split_or_candidate = false;
for candidate in &mut *candidates {
if let [MatchPair { test_case: TestCase::Or { .. }, .. }] = &*candidate.match_pairs {
// Split a candidate in which the only match-pair is an or-pattern into multiple
// candidates. This is so that
//
// match x {
// 0 | 1 => { ... },
// 2 | 3 => { ... },
// }
//
// only generates a single switch.
let match_pair = candidate.match_pairs.pop().unwrap();
self.create_or_subcandidates(candidate, match_pair);
split_or_candidate = true;
// We process or-patterns here. If any candidate starts with an or-pattern, we have to
// expand the or-pattern before we can proceed further.
//
// We can't expand them freely however. The rule is: if the candidate has an or-pattern as
// its only remaining match pair, we can expand it freely. If it has other match pairs, we
// can expand it but we can't process more candidates after it.
//
// If we didn't stop, the `otherwise` cases could get mixed up. E.g. in the following,
// or-pattern simplification (in `merge_trivial_subcandidates`) makes it so the `1` and `2`
// cases branch to the same block (which then tests `false`). If we took `(2, _)` in the same
// set of candidates, when we reach the block that tests `false` we don't know whether we
// came from `1` or `2`, hence we can't know where to branch on failure.
// ```ignore(illustrative)
// match (1, true) {
// (1 | 2, false) => {},
// (2, _) => {},
// _ => {}
// }
// ```
//
// We therefore split the `candidates` slice in two, expand or-patterns in the first half,
// and process both halves separately.
let mut expand_until = 0;
for (i, candidate) in candidates.iter().enumerate() {
if matches!(
&*candidate.match_pairs,
[MatchPair { test_case: TestCase::Or { .. }, .. }, ..]
) {
expand_until = i + 1;
if candidate.match_pairs.len() > 1 {
break;
}
}
}
let (candidates_to_expand, remaining_candidates) = candidates.split_at_mut(expand_until);
ensure_sufficient_stack(|| {
if split_or_candidate {
// At least one of the candidates has been split into subcandidates.
// We need to change the candidate list to include those.
let mut new_candidates = Vec::new();
for candidate in candidates.iter_mut() {
candidate.visit_leaves(|leaf_candidate| new_candidates.push(leaf_candidate));
if candidates_to_expand.is_empty() {
// No candidates start with an or-pattern, we can continue.
self.match_expanded_candidates(
span,
scrutinee_span,
start_block,
otherwise_block,
remaining_candidates,
);
} else {
// Expand one level of or-patterns for each candidate in `candidates_to_expand`.
let mut expanded_candidates = Vec::new();
for candidate in candidates_to_expand.iter_mut() {
if let [MatchPair { test_case: TestCase::Or { .. }, .. }, ..] =
&*candidate.match_pairs
{
let or_match_pair = candidate.match_pairs.remove(0);
// Expand the or-pattern into subcandidates.
self.create_or_subcandidates(candidate, or_match_pair);
// Collect the newly created subcandidates.
for subcandidate in candidate.subcandidates.iter_mut() {
expanded_candidates.push(subcandidate);
}
} else {
expanded_candidates.push(candidate);
}
}
// Process the expanded candidates.
let remainder_start = self.cfg.start_new_block();
// There might be new or-patterns obtained from expanding the old ones, so we call
// `match_candidates` again.
self.match_candidates(
span,
scrutinee_span,
start_block,
otherwise_block,
&mut *new_candidates,
remainder_start,
expanded_candidates.as_mut_slice(),
);
for candidate in candidates {
self.merge_trivial_subcandidates(candidate);
// Simplify subcandidates and process any leftover match pairs.
for candidate in candidates_to_expand {
if !candidate.subcandidates.is_empty() {
self.finalize_or_candidate(span, scrutinee_span, candidate);
}
}
} else {
self.match_simplified_candidates(
// Process the remaining candidates.
self.match_candidates(
span,
scrutinee_span,
start_block,
remainder_start,
otherwise_block,
candidates,
remaining_candidates,
);
}
});
}
fn match_simplified_candidates(
/// Construct the decision tree for `candidates`. Caller must ensure that no candidate in
/// `candidates` starts with an or-pattern.
fn match_expanded_candidates(
&mut self,
span: Span,
scrutinee_span: Span,
@ -1438,7 +1487,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// The first candidate has satisfied all its match pairs; we link it up and continue
// with the remaining candidates.
start_block = self.select_matched_candidate(first, start_block);
self.match_simplified_candidates(
self.match_expanded_candidates(
span,
scrutinee_span,
start_block,
@ -1448,7 +1497,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
candidates => {
// The first candidate has some unsatisfied match pairs; we proceed to do more tests.
self.test_candidates_with_or(
self.test_candidates(
span,
scrutinee_span,
candidates,
@ -1495,16 +1544,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
candidate.pre_binding_block = Some(start_block);
let otherwise_block = self.cfg.start_new_block();
if candidate.has_guard {
// Create the otherwise block for this candidate, which is the
// pre-binding block for the next candidate.
candidate.otherwise_block = Some(otherwise_block);
}
// Create the otherwise block for this candidate, which is the
// pre-binding block for the next candidate.
candidate.otherwise_block = Some(otherwise_block);
otherwise_block
}
/// Tests a candidate where there are only or-patterns left to test, or
/// forwards to [Builder::test_candidates].
/// Simplify subcandidates and process any leftover match pairs. The candidate should have been
/// expanded with `create_or_subcandidates`.
///
/// Given a pattern `(P | Q, R | S)` we (in principle) generate a CFG like
/// so:
@ -1556,84 +1603,56 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
/// |
/// ...
/// ```
fn test_candidates_with_or(
fn finalize_or_candidate(
&mut self,
span: Span,
scrutinee_span: Span,
candidates: &mut [&mut Candidate<'_, 'tcx>],
start_block: BasicBlock,
otherwise_block: BasicBlock,
candidate: &mut Candidate<'_, 'tcx>,
) {
let (first_candidate, remaining_candidates) = candidates.split_first_mut().unwrap();
assert!(first_candidate.subcandidates.is_empty());
if !matches!(first_candidate.match_pairs[0].test_case, TestCase::Or { .. }) {
self.test_candidates(span, scrutinee_span, candidates, start_block, otherwise_block);
if candidate.subcandidates.is_empty() {
return;
}
let first_match_pair = first_candidate.match_pairs.remove(0);
let remaining_match_pairs = mem::take(&mut first_candidate.match_pairs);
let remainder_start = self.cfg.start_new_block();
// Test the alternatives of this or-pattern.
self.test_or_pattern(first_candidate, start_block, remainder_start, first_match_pair);
self.merge_trivial_subcandidates(candidate);
if !remaining_match_pairs.is_empty() {
if !candidate.match_pairs.is_empty() {
// If more match pairs remain, test them after each subcandidate.
// We could add them to the or-candidates before the call to `test_or_pattern` but this
// would make it impossible to detect simplifiable or-patterns. That would guarantee
// exponentially large CFGs for cases like `(1 | 2, 3 | 4, ...)`.
first_candidate.visit_leaves(|leaf_candidate| {
let mut last_otherwise = None;
candidate.visit_leaves(|leaf_candidate| {
last_otherwise = leaf_candidate.otherwise_block;
});
let remaining_match_pairs = mem::take(&mut candidate.match_pairs);
candidate.visit_leaves(|leaf_candidate| {
assert!(leaf_candidate.match_pairs.is_empty());
leaf_candidate.match_pairs.extend(remaining_match_pairs.iter().cloned());
let or_start = leaf_candidate.pre_binding_block.unwrap();
// In a case like `(a | b, c | d)`, if `a` succeeds and `c | d` fails, we know `(b,
// c | d)` will fail too. If there is no guard, we skip testing of `b` by branching
// directly to `remainder_start`. If there is a guard, we have to try `(b, c | d)`.
let or_otherwise = leaf_candidate.otherwise_block.unwrap_or(remainder_start);
self.test_candidates_with_or(
// In a case like `(P | Q, R | S)`, if `P` succeeds and `R | S` fails, we know `(Q,
// R | S)` will fail too. If there is no guard, we skip testing of `Q` by branching
// directly to `last_otherwise`. If there is a guard,
// `leaf_candidate.otherwise_block` can be reached by guard failure as well, so we
// can't skip `Q`.
let or_otherwise = if leaf_candidate.has_guard {
leaf_candidate.otherwise_block.unwrap()
} else {
last_otherwise.unwrap()
};
self.match_candidates(
span,
scrutinee_span,
&mut [leaf_candidate],
or_start,
or_otherwise,
&mut [leaf_candidate],
);
});
}
// Test the remaining candidates.
self.match_candidates(
span,
scrutinee_span,
remainder_start,
otherwise_block,
remaining_candidates,
);
}
#[instrument(skip(self, start_block, otherwise_block, candidate, match_pair), level = "debug")]
fn test_or_pattern<'pat>(
&mut self,
candidate: &mut Candidate<'pat, 'tcx>,
start_block: BasicBlock,
otherwise_block: BasicBlock,
match_pair: MatchPair<'pat, 'tcx>,
) {
let or_span = match_pair.pattern.span;
self.create_or_subcandidates(candidate, match_pair);
let mut or_candidate_refs: Vec<_> = candidate.subcandidates.iter_mut().collect();
self.match_candidates(
or_span,
or_span,
start_block,
otherwise_block,
&mut or_candidate_refs,
);
self.merge_trivial_subcandidates(candidate);
}
/// Given a match-pair that corresponds to an or-pattern, expand each subpattern into a new
/// subcandidate. Any candidate that has been expanded that way should be passed to
/// `merge_trivial_subcandidates` after its subcandidates have been processed.
/// `finalize_or_candidate` after its subcandidates have been processed.
fn create_or_subcandidates<'pat>(
&mut self,
candidate: &mut Candidate<'pat, 'tcx>,
@ -1651,8 +1670,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
/// Try to merge all of the subcandidates of the given candidate into one. This avoids
/// exponentially large CFGs in cases like `(1 | 2, 3 | 4, ...)`. The or-pattern should have
/// been expanded with `create_or_subcandidates`.
/// exponentially large CFGs in cases like `(1 | 2, 3 | 4, ...)`. The candidate should have been
/// expanded with `create_or_subcandidates`.
fn merge_trivial_subcandidates(&mut self, candidate: &mut Candidate<'_, 'tcx>) {
if candidate.subcandidates.is_empty() || candidate.has_guard {
// FIXME(or_patterns; matthewjasper) Don't give up if we have a guard.
@ -1664,6 +1683,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
subcandidate.subcandidates.is_empty() && subcandidate.extra_data.is_empty()
});
if can_merge {
let mut last_otherwise = None;
let any_matches = self.cfg.start_new_block();
let or_span = candidate.or_span.take().unwrap();
let source_info = self.source_info(or_span);
@ -1674,8 +1694,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
for subcandidate in mem::take(&mut candidate.subcandidates) {
let or_block = subcandidate.pre_binding_block.unwrap();
self.cfg.goto(or_block, source_info, any_matches);
last_otherwise = subcandidate.otherwise_block;
}
candidate.pre_binding_block = Some(any_matches);
assert!(last_otherwise.is_some());
candidate.otherwise_block = last_otherwise;
} else {
// Never subcandidates may have a set of bindings inconsistent with their siblings,
// which would break later code. So we filter them out. Note that we can't filter out

View file

@ -0,0 +1,24 @@
// skip-filecheck
// EMIT_MIR or_pattern.shortcut_second_or.SimplifyCfg-initial.after.mir
fn shortcut_second_or() {
// Check that after `(0, _)` matches the first field, failing to match `2 | 3` on the second
// field goes straight to the `_` arm instead of also trying `((_, 1), 2 | 3)`: the second
// field is tested against the same `2 | 3` in both alternatives, so it would fail there too.
match ((0, 0), 0) {
(x @ (0, _) | x @ (_, 1), y @ 2 | y @ 3) => {}
_ => {}
}
}
// EMIT_MIR or_pattern.single_switchint.SimplifyCfg-initial.after.mir
fn single_switchint() {
// Check how many `SwitchInt`s on the first field this match is lowered to. In theory a single
// one would suffice, since every arm but the last tests the first field against integer
// constants; the `EMIT_MIR` dump for this function records how many are actually emitted.
match (1, true) {
(1, true) => 1,
(2, false) => 2,
(1 | 2, true | false) => 3,
(3 | 4, true | false) => 4,
_ => 5,
};
}
// Nothing to run: the functions above exist for their `EMIT_MIR` dumps.
fn main() {}

View file

@ -0,0 +1,100 @@
// MIR for `shortcut_second_or` after SimplifyCfg-initial
fn shortcut_second_or() -> () {
let mut _0: ();
let mut _1: ((i32, i32), i32);
let mut _2: (i32, i32);
let _3: (i32, i32);
let _4: i32;
scope 1 {
debug x => _3;
debug y => _4;
}
bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = (const 0_i32, const 0_i32);
_1 = (move _2, const 0_i32);
StorageDead(_2);
PlaceMention(_1);
switchInt(((_1.0: (i32, i32)).0: i32)) -> [0: bb4, otherwise: bb2];
}
bb1: {
_0 = const ();
goto -> bb14;
}
bb2: {
switchInt(((_1.0: (i32, i32)).1: i32)) -> [1: bb3, otherwise: bb1];
}
bb3: {
switchInt((_1.1: i32)) -> [2: bb7, 3: bb8, otherwise: bb1];
}
bb4: {
switchInt((_1.1: i32)) -> [2: bb5, 3: bb6, otherwise: bb1];
}
bb5: {
falseEdge -> [real: bb10, imaginary: bb6];
}
bb6: {
falseEdge -> [real: bb11, imaginary: bb2];
}
bb7: {
falseEdge -> [real: bb12, imaginary: bb8];
}
bb8: {
falseEdge -> [real: bb13, imaginary: bb1];
}
bb9: {
_0 = const ();
StorageDead(_4);
StorageDead(_3);
goto -> bb14;
}
bb10: {
StorageLive(_3);
_3 = (_1.0: (i32, i32));
StorageLive(_4);
_4 = (_1.1: i32);
goto -> bb9;
}
bb11: {
StorageLive(_3);
_3 = (_1.0: (i32, i32));
StorageLive(_4);
_4 = (_1.1: i32);
goto -> bb9;
}
bb12: {
StorageLive(_3);
_3 = (_1.0: (i32, i32));
StorageLive(_4);
_4 = (_1.1: i32);
goto -> bb9;
}
bb13: {
StorageLive(_3);
_3 = (_1.0: (i32, i32));
StorageLive(_4);
_4 = (_1.1: i32);
goto -> bb9;
}
bb14: {
StorageDead(_1);
return;
}
}

View file

@ -0,0 +1,75 @@
// MIR for `single_switchint` after SimplifyCfg-initial
fn single_switchint() -> () {
let mut _0: ();
let _1: i32;
let mut _2: (i32, bool);
bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = (const 1_i32, const true);
PlaceMention(_2);
switchInt((_2.0: i32)) -> [1: bb2, 2: bb4, otherwise: bb1];
}
bb1: {
switchInt((_2.0: i32)) -> [3: bb8, 4: bb8, otherwise: bb7];
}
bb2: {
switchInt((_2.1: bool)) -> [0: bb6, otherwise: bb3];
}
bb3: {
falseEdge -> [real: bb9, imaginary: bb4];
}
bb4: {
switchInt((_2.1: bool)) -> [0: bb5, otherwise: bb6];
}
bb5: {
falseEdge -> [real: bb10, imaginary: bb6];
}
bb6: {
falseEdge -> [real: bb11, imaginary: bb1];
}
bb7: {
_1 = const 5_i32;
goto -> bb13;
}
bb8: {
falseEdge -> [real: bb12, imaginary: bb7];
}
bb9: {
_1 = const 1_i32;
goto -> bb13;
}
bb10: {
_1 = const 2_i32;
goto -> bb13;
}
bb11: {
_1 = const 3_i32;
goto -> bb13;
}
bb12: {
_1 = const 4_i32;
goto -> bb13;
}
bb13: {
StorageDead(_2);
StorageDead(_1);
_0 = const ();
return;
}
}

View file

@ -26,5 +26,6 @@ fn main() {
assert_eq!(or_at(Err(7)), 207);
assert_eq!(or_at(Err(8)), 8);
assert_eq!(or_at(Err(20)), 220);
assert_eq!(or_at(Err(34)), 134);
assert_eq!(or_at(Err(50)), 500);
}

View file

@ -1,5 +1,5 @@
error[E0308]: mismatched types
--> $DIR/inner-or-pat.rs:38:54
--> $DIR/inner-or-pat.rs:36:54
|
LL | match x {
| - this expression has type `&str`

View file

@ -1,5 +1,5 @@
error[E0408]: variable `x` is not bound in all patterns
--> $DIR/inner-or-pat.rs:53:37
--> $DIR/inner-or-pat.rs:51:37
|
LL | (x @ "red" | (x @ "blue" | "red")) => {
| - ^^^^^ pattern doesn't bind `x`

View file

@ -1,7 +1,5 @@
//@ revisions: or1 or2 or3 or4 or5
//@ revisions: or1 or3 or4
//@ [or1] run-pass
//@ [or2] run-pass
//@ [or5] run-pass
#![allow(unreachable_patterns)]
#![allow(unused_variables)]

View file

@ -1,21 +1,20 @@
//@ check-pass
//@ run-pass
#![deny(unreachable_patterns)]
fn main() {
match (3,42) {
(a,_) | (_,a) if a > 10 => {println!("{}", a)}
_ => ()
match (3, 42) {
(a, _) | (_, a) if a > 10 => {}
_ => unreachable!(),
}
match Some((3,42)) {
Some((a, _)) | Some((_, a)) if a > 10 => {println!("{}", a)}
_ => ()
match Some((3, 42)) {
Some((a, _)) | Some((_, a)) if a > 10 => {}
_ => unreachable!(),
}
match Some((3,42)) {
Some((a, _) | (_, a)) if a > 10 => {println!("{}", a)}
_ => ()
match Some((3, 42)) {
Some((a, _) | (_, a)) if a > 10 => {}
_ => unreachable!(),
}
}

View file

@ -42,6 +42,23 @@ fn search_old_style(target: (bool, bool, bool)) -> u32 {
}
}
// Verify that adding a dummy or-pattern (`_ | _`) still causes the guard to be evaluated once
// per combination of alternatives. Returns the number of guard evaluations performed up to and
// including the one that accepted `target`.
fn search_with_dummy(target: (bool, bool)) -> u32 {
    let scrutinee = ((false, true), (false, true), ());
    // Incremented inside the guard, so it counts how many alternative combinations were tried.
    let mut evaluations: u32 = 0;
    match scrutinee {
        ((a, _) | (_, a), (b, _) | (_, b), _ | _)
            if {
                evaluations += 1;
                (a, b) == target
            } =>
        {
            evaluations
        }
        // Every `(bool, bool)` target is reachable through some combination of alternatives.
        _ => unreachable!(),
    }
}
fn main() {
assert_eq!(search((false, false, false)), 1);
assert_eq!(search((false, false, true)), 2);
@ -60,4 +77,9 @@ fn main() {
assert_eq!(search_old_style((true, false, true)), 6);
assert_eq!(search_old_style((true, true, false)), 7);
assert_eq!(search_old_style((true, true, true)), 8);
assert_eq!(search_with_dummy((false, false)), 1);
assert_eq!(search_with_dummy((false, true)), 3);
assert_eq!(search_with_dummy((true, false)), 5);
assert_eq!(search_with_dummy((true, true)), 7);
}

View file

@ -0,0 +1,11 @@
//@ run-pass
// `false | false` repeats the same alternative, which trips `unreachable_patterns`; presumably
// written that way so the second field is an or-pattern too — confirm.
#[allow(unreachable_patterns)]
fn main() {
// Test that we don't naively sort the two `2`s together and confuse the failure paths.
// The scrutinee matches `1` in the first arm but then fails `false`, so it must fall through
// to the `_` arm; hitting either `unreachable!()` means a failure path led to the wrong arm.
match (1, true) {
(1 | 2, false | false) => unreachable!(),
(2, _) => unreachable!(),
_ => {}
}
}