1
Fork 0

Rollup merge of #65226 - ssomers:master, r=bluss

BTreeSet symmetric_difference & union optimized

No scalability changes, but:
- Grew the cmp_opt function (shared by symmetric_difference & union) into a MergeIter, with less memory overhead than the pairs of Peekable iterators now, speeding up ~20% on my machine (not so clear on Travis though, I actually switched it off there because it wasn't consistent about identical code). Mainly meant to improve readability by sharing code, though it does end up using more lines of code. Extending and reusing the MergeIter in btree_map might be better, but I'm not sure that's possible or desirable. This MergeIter probably pretends to be more generic than it is, yet doesn't declare to be an iterator because there's no need to, it's only there to help construct genuine iterators SymmetricDifference & Union.
- Compact the code of #64820 by moving if/else into match guards.

r? @bluss
This commit is contained in:
Mazdak Farrokhzad 2019-10-19 16:00:53 +02:00 committed by GitHub
commit a6b5c80dbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 144 additions and 121 deletions

View file

@ -2,7 +2,7 @@
// to TreeMap
use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::Ordering::{Less, Greater, Equal};
use core::cmp::{max, min};
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
iter: btree_map::Range<'a, T, ()>,
}
/// Core of SymmetricDifference and Union.
/// More efficient than btree.map.MergeIter,
/// and crucially for SymmetricDifference, nexts() reports on both sides.
#[derive(Clone)]
struct MergeIterInner<I>
where I: Iterator,
I::Item: Copy,
{
a: I,
b: I,
peeked: Option<MergeIterPeeked<I>>,
}
#[derive(Copy, Clone, Debug)]
enum MergeIterPeeked<I: Iterator> {
A(I::Item),
B(I::Item),
}
impl<I> MergeIterInner<I>
where I: ExactSizeIterator + FusedIterator,
I::Item: Copy + Ord,
{
fn new(a: I, b: I) -> Self {
MergeIterInner { a, b, peeked: None }
}
fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
let mut a_next = match self.peeked {
Some(MergeIterPeeked::A(next)) => Some(next),
_ => self.a.next(),
};
let mut b_next = match self.peeked {
Some(MergeIterPeeked::B(next)) => Some(next),
_ => self.b.next(),
};
let ord = match (a_next, b_next) {
(None, None) => Equal,
(_, None) => Less,
(None, _) => Greater,
(Some(a1), Some(b1)) => a1.cmp(&b1),
};
self.peeked = match ord {
Less => b_next.take().map(MergeIterPeeked::B),
Equal => None,
Greater => a_next.take().map(MergeIterPeeked::A),
};
(a_next, b_next)
}
fn lens(&self) -> (usize, usize) {
match self.peeked {
Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
_ => (self.a.len(), self.b.len()),
}
}
}
impl<I> Debug for MergeIterInner<I>
where I: Iterator + Debug,
I::Item: Copy + Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MergeIterInner")
.field(&self.a)
.field(&self.b)
.finish()
}
}
/// A lazy iterator producing elements in the difference of `BTreeSet`s.
///
/// This `struct` is created by the [`difference`] method on [`BTreeSet`].
@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
pub struct Difference<'a, T: 'a> {
inner: DifferenceInner<'a, T>,
}
#[derive(Debug)]
enum DifferenceInner<'a, T: 'a> {
Stitch {
// iterate all of self and some of other, spotting matches along the way
@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => f
.debug_tuple("Difference")
.field(&self_iter)
.field(&other_iter)
.finish(),
DifferenceInner::Search {
self_iter,
other_set: _,
} => f.debug_tuple("Difference").field(&self_iter).finish(),
DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
}
f.debug_tuple("Difference").field(&self.inner).finish()
}
}
@ -163,18 +221,12 @@ impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
#[stable(feature = "rust1", since = "1.0.0")]
pub struct SymmetricDifference<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SymmetricDifference")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("SymmetricDifference").field(&self.0).finish()
}
}
@ -189,6 +241,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
pub struct Intersection<'a, T: 'a> {
inner: IntersectionInner<'a, T>,
}
#[derive(Debug)]
enum IntersectionInner<'a, T: 'a> {
Stitch {
// iterate similarly sized sets jointly, spotting matches along the way
@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
IntersectionInner::Stitch {
a,
b,
} => f
.debug_tuple("Intersection")
.field(&a)
.field(&b)
.finish(),
IntersectionInner::Search {
small_iter,
large_set: _,
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
IntersectionInner::Answer(answer) => {
f.debug_tuple("Intersection").field(&answer).finish()
}
}
f.debug_tuple("Intersection").field(&self.inner).finish()
}
}
@ -234,18 +271,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`union`]: struct.BTreeSet.html#method.union
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Union<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Union")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("Union").field(&self.0).finish()
}
}
@ -355,19 +386,16 @@ impl<T: Ord> BTreeSet<T> {
self_iter.next_back();
DifferenceInner::Iterate(self_iter)
}
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
}
} else {
DifferenceInner::Stitch {
}
_ => DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
}
}
}
},
},
}
}
@ -396,10 +424,7 @@ impl<T: Ord> BTreeSet<T> {
pub fn symmetric_difference<'a>(&'a self,
other: &'a BTreeSet<T>)
-> SymmetricDifference<'a, T> {
SymmetricDifference {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
}
/// Visits the values representing the intersection,
@ -447,24 +472,22 @@ impl<T: Ord> BTreeSet<T> {
(Greater, _) | (_, Less) => IntersectionInner::Answer(None),
(Equal, _) => IntersectionInner::Answer(Some(self_min)),
(_, Equal) => IntersectionInner::Answer(Some(self_max)),
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: self.iter(),
large_set: other,
}
} else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
}
_ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: other.iter(),
large_set: self,
}
} else {
IntersectionInner::Stitch {
}
_ => IntersectionInner::Stitch {
a: self.iter(),
b: other.iter(),
}
}
}
},
},
}
}
@ -489,10 +512,7 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
Union {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
Union(MergeIterInner::new(self.iter(), other.iter()))
}
/// Clears the set, removing all values.
@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> {
#[stable(feature = "fused", since = "1.26.0")]
impl<T> FusedIterator for Range<'_, T> {}
/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None
fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
match (x, y) {
(None, _) => short,
(_, None) => long,
(Some(x1), Some(y1)) => x1.cmp(y1),
}
}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Difference<'_, T> {
fn clone(&self) -> Self {
@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for SymmetricDifference<'_, T> {
fn clone(&self) -> Self {
SymmetricDifference {
a: self.a.clone(),
b: self.b.clone(),
}
SymmetricDifference(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
fn next(&mut self) -> Option<&'a T> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => return self.a.next(),
Equal => {
self.a.next();
self.b.next();
}
Greater => return self.b.next(),
let (a_next, b_next) = self.0.nexts();
if a_next.and(b_next).is_none() {
return a_next.or(b_next);
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.a.len() + self.b.len()))
let (a_len, b_len) = self.0.lens();
// No checked_add, because even if a and b refer to the same set,
// and T is an empty type, the storage overhead of sets limits
// the number of elements to less than half the range of usize.
(0, Some(a_len + b_len))
}
}
@ -1311,7 +1319,7 @@ impl<T> Clone for Intersection<'_, T> {
small_iter: small_iter.clone(),
large_set,
},
IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
},
}
}
@ -1365,10 +1373,7 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Union<'_, T> {
fn clone(&self) -> Self {
Union {
a: self.a.clone(),
b: self.b.clone(),
}
Union(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
@ -1376,19 +1381,13 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => self.a.next(),
Equal => {
self.b.next();
self.a.next()
}
Greater => self.b.next(),
}
let (a_next, b_next) = self.0.nexts();
a_next.or(b_next)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let a_len = self.a.len();
let b_len = self.b.len();
let (a_len, b_len) = self.0.lens();
// No checked_add - see SymmetricDifference::size_hint.
(max(a_len, b_len), Some(a_len + b_len))
}
}

View file

@ -221,6 +221,18 @@ fn test_symmetric_difference() {
&[-2, 1, 5, 11, 14, 22]);
}
#[test]
fn test_symmetric_difference_size_hint() {
let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
let mut iter = x.symmetric_difference(&y);
assert_eq!(iter.size_hint(), (0, Some(5)));
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.size_hint(), (0, Some(4)));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.size_hint(), (0, Some(1)));
}
#[test]
fn test_union() {
fn check_union(a: &[i32], b: &[i32], expected: &[i32]) {
@ -235,6 +247,18 @@ fn test_union() {
&[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]);
}
#[test]
fn test_union_size_hint() {
let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
let mut iter = x.union(&y);
assert_eq!(iter.size_hint(), (3, Some(5)));
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.size_hint(), (2, Some(4)));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.size_hint(), (1, Some(2)));
}
#[test]
// Only tests the simple function definition with respect to intersection
fn test_is_disjoint() {
@ -244,7 +268,7 @@ fn test_is_disjoint() {
}
#[test]
// Also tests the trivial function definition of is_superset
// Also implicitly tests the trivial function definition of is_superset
fn test_is_subset() {
fn is_subset(a: &[i32], b: &[i32]) -> bool {
let set_a = a.iter().collect::<BTreeSet<_>>();