1
Fork 0

auto merge of #17558 : kaseyc/rust/fix_bitvset_union, r=aturon

Updates the other_op function shared by the union/intersect/difference/symmetric_difference -with functions to fix an issue where certain elements would not be present in the result. To fix this, when other op is called, we resize self's nbits to account for any new elements that may be added to the set.

Example:
```rust
	let mut a = BitvSet::new();
	let mut b = BitvSet::new();
	a.insert(0);
	b.insert(5);
	a.union_with(&b);
	println!("{}", a); //Prints "{0}" instead of "{0, 5}"
```
This commit is contained in:
bors 2014-10-09 19:02:06 +00:00
commit 79d056f94b

View file

@ -235,11 +235,20 @@ impl Bitv {
/// }
/// ```
pub fn with_capacity(nbits: uint, init: bool) -> Bitv {
Bitv {
let mut bitv = Bitv {
storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
if init { !0u } else { 0u }),
nbits: nbits
};
// Zero out any unused bits in the highest word if necessary
let used_bits = bitv.nbits % uint::BITS;
if init && used_bits != 0 {
let largest_used_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(largest_used_word) &= (1 << used_bits) - 1;
}
bitv
}
/// Retrieves the value at index `i`.
@ -629,9 +638,9 @@ impl Bitv {
/// ```
pub fn reserve(&mut self, size: uint) {
let old_size = self.storage.len();
let size = (size + uint::BITS - 1) / uint::BITS;
if old_size < size {
self.storage.grow(size - old_size, 0);
let new_size = (size + uint::BITS - 1) / uint::BITS;
if old_size < new_size {
self.storage.grow(new_size - old_size, 0);
}
}
@ -686,8 +695,15 @@ impl Bitv {
}
// Allocate new words, if needed
if new_nwords > self.storage.len() {
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);
// Zero out and unused bits in the new tail word
if value {
let tail_word = new_nwords - 1;
let used_bits = new_nbits % uint::BITS;
*self.storage.get_mut(tail_word) &= (1 << used_bits) - 1;
}
}
// Adjust internal bit count
self.nbits = new_nbits;
@ -970,9 +986,8 @@ impl<'a> RandomAccessIterator<bool> for Bits<'a> {
/// }
///
/// // Can convert back to a `Bitv`
/// let bv: Bitv = s.unwrap();
/// assert!(bv.eq_vec([true, true, false, true,
/// false, false, false, false]));
/// let bv: Bitv = s.into_bitv();
/// assert!(bv.get(3));
/// ```
#[deriving(Clone)]
pub struct BitvSet(Bitv);
@ -993,7 +1008,8 @@ impl FromIterator<bool> for BitvSet {
impl Extendable<bool> for BitvSet {
#[inline]
fn extend<I: Iterator<bool>>(&mut self, iterator: I) {
self.get_mut_ref().extend(iterator);
let &BitvSet(ref mut self_bitv) = self;
self_bitv.extend(iterator);
}
}
@ -1049,7 +1065,8 @@ impl BitvSet {
/// ```
#[inline]
pub fn with_capacity(nbits: uint) -> BitvSet {
BitvSet(Bitv::with_capacity(nbits, false))
let bitv = Bitv::with_capacity(nbits, false);
BitvSet::from_bitv(bitv)
}
/// Creates a new bit vector set from the given bit vector.
@ -1068,7 +1085,9 @@ impl BitvSet {
/// }
/// ```
#[inline]
pub fn from_bitv(bitv: Bitv) -> BitvSet {
pub fn from_bitv(mut bitv: Bitv) -> BitvSet {
// Mark every bit as valid
bitv.nbits = bitv.capacity();
BitvSet(bitv)
}
@ -1102,7 +1121,10 @@ impl BitvSet {
/// ```
pub fn reserve(&mut self, size: uint) {
let &BitvSet(ref mut bitv) = self;
bitv.reserve(size)
bitv.reserve(size);
if bitv.nbits < size {
bitv.nbits = bitv.capacity();
}
}
/// Consumes this set to return the underlying bit vector.
@ -1116,11 +1138,12 @@ impl BitvSet {
/// s.insert(0);
/// s.insert(3);
///
/// let bv = s.unwrap();
/// assert!(bv.eq_vec([true, false, false, true]));
/// let bv = s.into_bitv();
/// assert!(bv.get(0));
/// assert!(bv.get(3));
/// ```
#[inline]
pub fn unwrap(self) -> Bitv {
pub fn into_bitv(self) -> Bitv {
let BitvSet(bitv) = self;
bitv
}
@ -1144,38 +1167,15 @@ impl BitvSet {
bitv
}
/// Returns a mutable reference to the underlying bit vector.
///
/// # Example
///
/// ```
/// use std::collections::BitvSet;
///
/// let mut s = BitvSet::new();
/// s.insert(0);
/// assert_eq!(s.contains(&0), true);
/// {
/// // Will free the set during bv's lifetime
/// let bv = s.get_mut_ref();
/// bv.set(0, false);
/// }
/// assert_eq!(s.contains(&0), false);
/// ```
#[inline]
pub fn get_mut_ref<'a>(&'a mut self) -> &'a mut Bitv {
let &BitvSet(ref mut bitv) = self;
bitv
}
#[inline]
fn other_op(&mut self, other: &BitvSet, f: |uint, uint| -> uint) {
// Expand the vector if necessary
self.reserve(other.capacity());
// Unwrap Bitvs
let &BitvSet(ref mut self_bitv) = self;
let &BitvSet(ref other_bitv) = other;
// Expand the vector if necessary
self_bitv.reserve(other_bitv.capacity());
// virtually pad other with 0's for equal lengths
let mut other_words = {
let (_, result) = match_words(self_bitv, other_bitv);
@ -1376,9 +1376,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.union_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn union_with(&mut self, other: &BitvSet) {
@ -1399,9 +1400,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.intersect_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn intersect_with(&mut self, other: &BitvSet) {
@ -1424,15 +1426,17 @@ impl BitvSet {
///
/// let mut bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let bva_b = BitvSet::from_bitv(bitv::from_bytes([a_b]));
/// let bvb_a = BitvSet::from_bitv(bitv::from_bytes([b_a]));
///
/// bva.difference_with(&bvb);
/// assert_eq!(bva.unwrap(), bitv::from_bytes([a_b]));
/// assert_eq!(bva, bva_b);
///
/// let bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let mut bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
///
/// bvb.difference_with(&bva);
/// assert_eq!(bvb.unwrap(), bitv::from_bytes([b_a]));
/// assert_eq!(bvb, bvb_a);
/// ```
#[inline]
pub fn difference_with(&mut self, other: &BitvSet) {
@ -1454,9 +1458,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.symmetric_difference_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn symmetric_difference_with(&mut self, other: &BitvSet) {
@ -1538,20 +1543,14 @@ impl MutableSet<uint> for BitvSet {
if self.contains(&value) {
return false;
}
// Ensure we have enough space to hold the new element
if value >= self.capacity() {
let new_cap = cmp::max(value + 1, self.capacity() * 2);
self.reserve(new_cap);
}
let &BitvSet(ref mut bitv) = self;
if value >= bitv.nbits {
// If we are increasing nbits, make sure we mask out any previously-unconsidered bits
let old_rem = bitv.nbits % uint::BITS;
if old_rem != 0 {
let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
}
bitv.nbits = value + 1;
}
bitv.set(value, true);
return true;
}
@ -2225,6 +2224,7 @@ mod tests {
assert!(a.insert(160));
assert!(a.insert(19));
assert!(a.insert(24));
assert!(a.insert(200));
assert!(b.insert(1));
assert!(b.insert(5));
@ -2232,7 +2232,7 @@ mod tests {
assert!(b.insert(13));
assert!(b.insert(19));
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160];
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160, 200];
let actual = a.union(&b).collect::<Vec<uint>>();
assert_eq!(actual.as_slice(), expected.as_slice());
}
@ -2281,6 +2281,27 @@ mod tests {
assert!(c.is_disjoint(&b))
}
#[test]
fn test_bitv_set_union_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
let mut b = BitvSet::new();
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.union_with(&b);
assert_eq!(a, expected);
// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.union_with(&b);
b.union_with(&c);
assert_eq!(a.len(), 4);
assert_eq!(b.len(), 4);
}
#[test]
fn test_bitv_set_intersect_with() {
// Explicitly 0'ed bits
@ -2311,6 +2332,59 @@ mod tests {
assert_eq!(b.len(), 2);
}
#[test]
fn test_bitv_set_difference_with() {
// Explicitly 0'ed bits
let mut a = BitvSet::from_bitv(from_bytes([0b00000000]));
let b = BitvSet::from_bitv(from_bytes([0b10100010]));
a.difference_with(&b);
assert!(a.is_empty());
// Uninitialized bits should behave like 0's
let mut a = BitvSet::new();
let b = BitvSet::from_bitv(from_bytes([0b11111111]));
a.difference_with(&b);
assert!(a.is_empty());
// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.difference_with(&b);
b.difference_with(&c);
assert_eq!(a.len(), 1);
assert_eq!(b.len(), 1);
}
#[test]
fn test_bitv_set_symmetric_difference_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
a.insert(1);
let mut b = BitvSet::new();
b.insert(1);
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.symmetric_difference_with(&b);
assert_eq!(a, expected);
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let b = BitvSet::new();
let c = a.clone();
a.symmetric_difference_with(&b);
assert_eq!(a, c);
// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b11100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01101010]));
let c = a.clone();
a.symmetric_difference_with(&b);
b.symmetric_difference_with(&c);
assert_eq!(a.len(), 2);
assert_eq!(b.len(), 2);
}
#[test]
fn test_bitv_set_eq() {
let a = BitvSet::from_bitv(from_bytes([0b10100010]));