diff --git a/src/lib/map.rs b/src/lib/map.rs index ced31513b63..ce4f065f64d 100644 --- a/src/lib/map.rs +++ b/src/lib/map.rs @@ -13,6 +13,7 @@ type hashfn[K] = fn(&K) -> uint; type eqfn[K] = fn(&K, &K) -> bool; type hashmap[K, V] = obj { + fn size() -> uint; fn insert(&K key, &V val) -> bool; fn contains_key(&K key) -> bool; fn get(&K key) -> V; @@ -141,6 +142,8 @@ fn mk_hashmap[K, V](&hashfn[K] hasher, &eqfn[K] eqer) -> hashmap[K, V] { mutable uint nelts, util.rational lf) { + fn size() -> uint { ret nelts; } + fn insert(&K key, &V val) -> bool { let util.rational load = rec(num=(nelts + 1u) as int, den=nbkts as int); if (!util.rational_leq(load, lf)) { @@ -181,17 +184,19 @@ fn mk_hashmap[K, V](&hashfn[K] hasher, &eqfn[K] eqer) -> hashmap[K, V] { while (i < nbkts) { let uint j = (hash[K](hasher, nbkts, key, i)); alt (bkts.(j)) { - case (some[K, V](_, val)) { - bkts.(j) = deleted[K, V](); - ret util.some[V](val); - } - case (deleted[K, V]()) { - nelts += 1u; + case (some[K, V](k, v)) { + if (eqer(key, k)) { + bkts.(j) = deleted[K, V](); + nelts -= 1u; + ret util.some[V](v); + } } + case (deleted[K, V]()) { } case (nil[K, V]()) { ret util.none[V](); } } + i += 1u; } ret util.none[V](); } diff --git a/src/test/run-pass/lib-map.rs b/src/test/run-pass/lib-map.rs index 51285c067c4..a59bdf0002a 100644 --- a/src/test/run-pass/lib-map.rs +++ b/src/test/run-pass/lib-map.rs @@ -2,6 +2,7 @@ use std; import std.map; +import std.util; fn test_simple() { log "*** starting test_simple"; @@ -17,18 +18,18 @@ fn test_simple() { let map.eqfn[uint] eqer = eq; let map.hashmap[uint, uint] hm = map.mk_hashmap[uint, uint](hasher, eqer); - hm.insert(10u, 12u); - hm.insert(11u, 13u); - hm.insert(12u, 14u); + check (hm.insert(10u, 12u)); + check (hm.insert(11u, 13u)); + check (hm.insert(12u, 14u)); check (hm.get(11u) == 13u); check (hm.get(12u) == 14u); check (hm.get(10u) == 12u); - hm.insert(12u, 14u); + check (!hm.insert(12u, 14u)); check (hm.get(12u) == 14u); - hm.insert(12u, 12u); + check (!hm.insert(12u, 12u)); check (hm.get(12u) == 12u); log "*** finished test_simple"; @@ -55,7 +56,7 @@ fn test_growth() { let uint i = 0u; while (i < num_to_insert) { - hm.insert(i, i * i); + check (hm.insert(i, i * i)); log "inserting " + std._uint.to_str(i, 10u) + " -> " + std._uint.to_str(i * i, 10u); i += 1u; @@ -71,7 +72,7 @@ fn test_growth() { i += 1u; } - hm.insert(num_to_insert, 17u); + check (hm.insert(num_to_insert, 17u)); check (hm.get(num_to_insert) == 17u); log "-----"; @@ -89,7 +90,128 @@ fn test_growth() { log "*** finished test_growth"; } +fn test_removal() { + log "*** starting test_removal"; + + let uint num_to_insert = 64u; + + fn eq(&uint x, &uint y) -> bool { ret x == y; } + fn hash(&uint u) -> uint { + // This hash function intentionally causes collisions between + // consecutive integer pairs. + ret (u / 2u) * 2u; + } + + let map.hashfn[uint] hasher = hash; + let map.eqfn[uint] eqer = eq; + let map.hashmap[uint, uint] hm = map.mk_hashmap[uint, uint](hasher, eqer); + + let uint i = 0u; + while (i < num_to_insert) { + check (hm.insert(i, i * i)); + log "inserting " + std._uint.to_str(i, 10u) + + " -> " + std._uint.to_str(i * i, 10u); + i += 1u; + } + + check (hm.size() == num_to_insert); + + log "-----"; + log "removing evens"; + + i = 0u; + while (i < num_to_insert) { + /** + * FIXME (issue #150): we want to check the removed value as in the + * following: + + let util.option[uint] v = hm.remove(i); + alt (v) { + case (util.some[uint](u)) { + check (u == (i * i)); + } + case (util.none[uint]()) { fail; } + } + + * but we util.option is a tag type so util.some and util.none are + * off limits until we parse the dwarf for tag types. + */ + + hm.remove(i); + i += 2u; + } + + check (hm.size() == (num_to_insert / 2u)); + + log "-----"; + + i = 1u; + while (i < num_to_insert) { + log "get(" + std._uint.to_str(i, 10u) + ") = " + + std._uint.to_str(hm.get(i), 10u); + check (hm.get(i) == i * i); + i += 2u; + } + + log "-----"; + log "rehashing"; + + hm.rehash(); + + log "-----"; + + i = 1u; + while (i < num_to_insert) { + log "get(" + std._uint.to_str(i, 10u) + ") = " + + std._uint.to_str(hm.get(i), 10u); + check (hm.get(i) == i * i); + i += 2u; + } + + log "-----"; + + i = 0u; + while (i < num_to_insert) { + check (hm.insert(i, i * i)); + log "inserting " + std._uint.to_str(i, 10u) + + " -> " + std._uint.to_str(i * i, 10u); + i += 2u; + } + + check (hm.size() == num_to_insert); + + log "-----"; + + i = 0u; + while (i < num_to_insert) { + log "get(" + std._uint.to_str(i, 10u) + ") = " + + std._uint.to_str(hm.get(i), 10u); + check (hm.get(i) == i * i); + i += 1u; + } + + log "-----"; + log "rehashing"; + + hm.rehash(); + + log "-----"; + + check (hm.size() == num_to_insert); + + i = 0u; + while (i < num_to_insert) { + log "get(" + std._uint.to_str(i, 10u) + ") = " + + std._uint.to_str(hm.get(i), 10u); + check (hm.get(i) == i * i); + i += 1u; + } + + log "*** finished test_removal"; +} + fn main() { test_simple(); test_growth(); + test_removal(); }