BTreeSet symmetric_difference & union optimized, cleaned
This commit is contained in:
parent
3da6836cc9
commit
56974329d1
2 changed files with 144 additions and 121 deletions
|
@ -2,7 +2,7 @@
|
|||
// to TreeMap
|
||||
|
||||
use core::borrow::Borrow;
|
||||
use core::cmp::Ordering::{self, Less, Greater, Equal};
|
||||
use core::cmp::Ordering::{Less, Greater, Equal};
|
||||
use core::cmp::{max, min};
|
||||
use core::fmt::{self, Debug};
|
||||
use core::iter::{Peekable, FromIterator, FusedIterator};
|
||||
|
@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
|
|||
iter: btree_map::Range<'a, T, ()>,
|
||||
}
|
||||
|
||||
/// Core of SymmetricDifference and Union.
|
||||
/// More efficient than btree.map.MergeIter,
|
||||
/// and crucially for SymmetricDifference, nexts() reports on both sides.
|
||||
#[derive(Clone)]
|
||||
struct MergeIterInner<I>
|
||||
where I: Iterator,
|
||||
I::Item: Copy,
|
||||
{
|
||||
a: I,
|
||||
b: I,
|
||||
peeked: Option<MergeIterPeeked<I>>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
enum MergeIterPeeked<I: Iterator> {
|
||||
A(I::Item),
|
||||
B(I::Item),
|
||||
}
|
||||
|
||||
impl<I> MergeIterInner<I>
|
||||
where I: ExactSizeIterator + FusedIterator,
|
||||
I::Item: Copy + Ord,
|
||||
{
|
||||
fn new(a: I, b: I) -> Self {
|
||||
MergeIterInner { a, b, peeked: None }
|
||||
}
|
||||
|
||||
fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
|
||||
let mut a_next = match self.peeked {
|
||||
Some(MergeIterPeeked::A(next)) => Some(next),
|
||||
_ => self.a.next(),
|
||||
};
|
||||
let mut b_next = match self.peeked {
|
||||
Some(MergeIterPeeked::B(next)) => Some(next),
|
||||
_ => self.b.next(),
|
||||
};
|
||||
let ord = match (a_next, b_next) {
|
||||
(None, None) => Equal,
|
||||
(_, None) => Less,
|
||||
(None, _) => Greater,
|
||||
(Some(a1), Some(b1)) => a1.cmp(&b1),
|
||||
};
|
||||
self.peeked = match ord {
|
||||
Less => b_next.take().map(MergeIterPeeked::B),
|
||||
Equal => None,
|
||||
Greater => a_next.take().map(MergeIterPeeked::A),
|
||||
};
|
||||
(a_next, b_next)
|
||||
}
|
||||
|
||||
fn lens(&self) -> (usize, usize) {
|
||||
match self.peeked {
|
||||
Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
|
||||
Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
|
||||
_ => (self.a.len(), self.b.len()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Debug for MergeIterInner<I>
|
||||
where I: Iterator + Debug,
|
||||
I::Item: Copy + Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("MergeIterInner")
|
||||
.field(&self.a)
|
||||
.field(&self.b)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// A lazy iterator producing elements in the difference of `BTreeSet`s.
|
||||
///
|
||||
/// This `struct` is created by the [`difference`] method on [`BTreeSet`].
|
||||
|
@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
|
|||
pub struct Difference<'a, T: 'a> {
|
||||
inner: DifferenceInner<'a, T>,
|
||||
}
|
||||
#[derive(Debug)]
|
||||
enum DifferenceInner<'a, T: 'a> {
|
||||
Stitch {
|
||||
// iterate all of self and some of other, spotting matches along the way
|
||||
|
@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
|
|||
#[stable(feature = "collection_debug", since = "1.17.0")]
|
||||
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.inner {
|
||||
DifferenceInner::Stitch {
|
||||
self_iter,
|
||||
other_iter,
|
||||
} => f
|
||||
.debug_tuple("Difference")
|
||||
.field(&self_iter)
|
||||
.field(&other_iter)
|
||||
.finish(),
|
||||
DifferenceInner::Search {
|
||||
self_iter,
|
||||
other_set: _,
|
||||
} => f.debug_tuple("Difference").field(&self_iter).finish(),
|
||||
DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
|
||||
}
|
||||
f.debug_tuple("Difference").field(&self.inner).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,18 +221,12 @@ impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
|
|||
/// [`BTreeSet`]: struct.BTreeSet.html
|
||||
/// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
pub struct SymmetricDifference<'a, T: 'a> {
|
||||
a: Peekable<Iter<'a, T>>,
|
||||
b: Peekable<Iter<'a, T>>,
|
||||
}
|
||||
pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
|
||||
|
||||
#[stable(feature = "collection_debug", since = "1.17.0")]
|
||||
impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("SymmetricDifference")
|
||||
.field(&self.a)
|
||||
.field(&self.b)
|
||||
.finish()
|
||||
f.debug_tuple("SymmetricDifference").field(&self.0).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,6 +241,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
|
|||
pub struct Intersection<'a, T: 'a> {
|
||||
inner: IntersectionInner<'a, T>,
|
||||
}
|
||||
#[derive(Debug)]
|
||||
enum IntersectionInner<'a, T: 'a> {
|
||||
Stitch {
|
||||
// iterate similarly sized sets jointly, spotting matches along the way
|
||||
|
@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
|
|||
#[stable(feature = "collection_debug", since = "1.17.0")]
|
||||
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.inner {
|
||||
IntersectionInner::Stitch {
|
||||
a,
|
||||
b,
|
||||
} => f
|
||||
.debug_tuple("Intersection")
|
||||
.field(&a)
|
||||
.field(&b)
|
||||
.finish(),
|
||||
IntersectionInner::Search {
|
||||
small_iter,
|
||||
large_set: _,
|
||||
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
|
||||
IntersectionInner::Answer(answer) => {
|
||||
f.debug_tuple("Intersection").field(&answer).finish()
|
||||
}
|
||||
}
|
||||
f.debug_tuple("Intersection").field(&self.inner).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -234,18 +271,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
|
|||
/// [`BTreeSet`]: struct.BTreeSet.html
|
||||
/// [`union`]: struct.BTreeSet.html#method.union
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
pub struct Union<'a, T: 'a> {
|
||||
a: Peekable<Iter<'a, T>>,
|
||||
b: Peekable<Iter<'a, T>>,
|
||||
}
|
||||
pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
|
||||
|
||||
#[stable(feature = "collection_debug", since = "1.17.0")]
|
||||
impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("Union")
|
||||
.field(&self.a)
|
||||
.field(&self.b)
|
||||
.finish()
|
||||
f.debug_tuple("Union").field(&self.0).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -355,19 +386,16 @@ impl<T: Ord> BTreeSet<T> {
|
|||
self_iter.next_back();
|
||||
DifferenceInner::Iterate(self_iter)
|
||||
}
|
||||
_ => {
|
||||
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
|
||||
DifferenceInner::Search {
|
||||
self_iter: self.iter(),
|
||||
other_set: other,
|
||||
}
|
||||
} else {
|
||||
DifferenceInner::Stitch {
|
||||
self_iter: self.iter(),
|
||||
other_iter: other.iter().peekable(),
|
||||
}
|
||||
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
|
||||
DifferenceInner::Search {
|
||||
self_iter: self.iter(),
|
||||
other_set: other,
|
||||
}
|
||||
}
|
||||
_ => DifferenceInner::Stitch {
|
||||
self_iter: self.iter(),
|
||||
other_iter: other.iter().peekable(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -396,10 +424,7 @@ impl<T: Ord> BTreeSet<T> {
|
|||
pub fn symmetric_difference<'a>(&'a self,
|
||||
other: &'a BTreeSet<T>)
|
||||
-> SymmetricDifference<'a, T> {
|
||||
SymmetricDifference {
|
||||
a: self.iter().peekable(),
|
||||
b: other.iter().peekable(),
|
||||
}
|
||||
SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
|
||||
}
|
||||
|
||||
/// Visits the values representing the intersection,
|
||||
|
@ -447,24 +472,22 @@ impl<T: Ord> BTreeSet<T> {
|
|||
(Greater, _) | (_, Less) => IntersectionInner::Answer(None),
|
||||
(Equal, _) => IntersectionInner::Answer(Some(self_min)),
|
||||
(_, Equal) => IntersectionInner::Answer(Some(self_max)),
|
||||
_ => {
|
||||
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
|
||||
IntersectionInner::Search {
|
||||
small_iter: self.iter(),
|
||||
large_set: other,
|
||||
}
|
||||
} else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
|
||||
IntersectionInner::Search {
|
||||
small_iter: other.iter(),
|
||||
large_set: self,
|
||||
}
|
||||
} else {
|
||||
IntersectionInner::Stitch {
|
||||
a: self.iter(),
|
||||
b: other.iter(),
|
||||
}
|
||||
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
|
||||
IntersectionInner::Search {
|
||||
small_iter: self.iter(),
|
||||
large_set: other,
|
||||
}
|
||||
}
|
||||
_ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
|
||||
IntersectionInner::Search {
|
||||
small_iter: other.iter(),
|
||||
large_set: self,
|
||||
}
|
||||
}
|
||||
_ => IntersectionInner::Stitch {
|
||||
a: self.iter(),
|
||||
b: other.iter(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -489,10 +512,7 @@ impl<T: Ord> BTreeSet<T> {
|
|||
/// ```
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
|
||||
Union {
|
||||
a: self.iter().peekable(),
|
||||
b: other.iter().peekable(),
|
||||
}
|
||||
Union(MergeIterInner::new(self.iter(), other.iter()))
|
||||
}
|
||||
|
||||
/// Clears the set, removing all values.
|
||||
|
@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> {
|
|||
#[stable(feature = "fused", since = "1.26.0")]
|
||||
impl<T> FusedIterator for Range<'_, T> {}
|
||||
|
||||
/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None
|
||||
fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
|
||||
match (x, y) {
|
||||
(None, _) => short,
|
||||
(_, None) => long,
|
||||
(Some(x1), Some(y1)) => x1.cmp(y1),
|
||||
}
|
||||
}
|
||||
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
impl<T> Clone for Difference<'_, T> {
|
||||
fn clone(&self) -> Self {
|
||||
|
@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
|
|||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
impl<T> Clone for SymmetricDifference<'_, T> {
|
||||
fn clone(&self) -> Self {
|
||||
SymmetricDifference {
|
||||
a: self.a.clone(),
|
||||
b: self.b.clone(),
|
||||
}
|
||||
SymmetricDifference(self.0.clone())
|
||||
}
|
||||
}
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
|
@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
|
|||
|
||||
fn next(&mut self) -> Option<&'a T> {
|
||||
loop {
|
||||
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
|
||||
Less => return self.a.next(),
|
||||
Equal => {
|
||||
self.a.next();
|
||||
self.b.next();
|
||||
}
|
||||
Greater => return self.b.next(),
|
||||
let (a_next, b_next) = self.0.nexts();
|
||||
if a_next.and(b_next).is_none() {
|
||||
return a_next.or(b_next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
(0, Some(self.a.len() + self.b.len()))
|
||||
let (a_len, b_len) = self.0.lens();
|
||||
// No checked_add, because even if a and b refer to the same set,
|
||||
// and T is an empty type, the storage overhead of sets limits
|
||||
// the number of elements to less than half the range of usize.
|
||||
(0, Some(a_len + b_len))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1311,7 +1319,7 @@ impl<T> Clone for Intersection<'_, T> {
|
|||
small_iter: small_iter.clone(),
|
||||
large_set,
|
||||
},
|
||||
IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
|
||||
IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -1365,10 +1373,7 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
|
|||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
impl<T> Clone for Union<'_, T> {
|
||||
fn clone(&self) -> Self {
|
||||
Union {
|
||||
a: self.a.clone(),
|
||||
b: self.b.clone(),
|
||||
}
|
||||
Union(self.0.clone())
|
||||
}
|
||||
}
|
||||
#[stable(feature = "rust1", since = "1.0.0")]
|
||||
|
@ -1376,19 +1381,13 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
|
|||
type Item = &'a T;
|
||||
|
||||
fn next(&mut self) -> Option<&'a T> {
|
||||
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
|
||||
Less => self.a.next(),
|
||||
Equal => {
|
||||
self.b.next();
|
||||
self.a.next()
|
||||
}
|
||||
Greater => self.b.next(),
|
||||
}
|
||||
let (a_next, b_next) = self.0.nexts();
|
||||
a_next.or(b_next)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
let a_len = self.a.len();
|
||||
let b_len = self.b.len();
|
||||
let (a_len, b_len) = self.0.lens();
|
||||
// No checked_add - see SymmetricDifference::size_hint.
|
||||
(max(a_len, b_len), Some(a_len + b_len))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -221,6 +221,18 @@ fn test_symmetric_difference() {
|
|||
&[-2, 1, 5, 11, 14, 22]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_symmetric_difference_size_hint() {
|
||||
let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
|
||||
let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
|
||||
let mut iter = x.symmetric_difference(&y);
|
||||
assert_eq!(iter.size_hint(), (0, Some(5)));
|
||||
assert_eq!(iter.next(), Some(&1));
|
||||
assert_eq!(iter.size_hint(), (0, Some(4)));
|
||||
assert_eq!(iter.next(), Some(&3));
|
||||
assert_eq!(iter.size_hint(), (0, Some(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_union() {
|
||||
fn check_union(a: &[i32], b: &[i32], expected: &[i32]) {
|
||||
|
@ -235,6 +247,18 @@ fn test_union() {
|
|||
&[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_union_size_hint() {
|
||||
let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
|
||||
let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
|
||||
let mut iter = x.union(&y);
|
||||
assert_eq!(iter.size_hint(), (3, Some(5)));
|
||||
assert_eq!(iter.next(), Some(&1));
|
||||
assert_eq!(iter.size_hint(), (2, Some(4)));
|
||||
assert_eq!(iter.next(), Some(&2));
|
||||
assert_eq!(iter.size_hint(), (1, Some(2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Only tests the simple function definition with respect to intersection
|
||||
fn test_is_disjoint() {
|
||||
|
@ -244,7 +268,7 @@ fn test_is_disjoint() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
// Also tests the trivial function definition of is_superset
|
||||
// Also implicitly tests the trivial function definition of is_superset
|
||||
fn test_is_subset() {
|
||||
fn is_subset(a: &[i32], b: &[i32]) -> bool {
|
||||
let set_a = a.iter().collect::<BTreeSet<_>>();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue