improve worst-case performance of BTreeSet difference and intersection

This commit is contained in:
Stein Somers 2019-03-13 23:01:12 +01:00
parent 4fec737f9a
commit f5fee8fd7d
3 changed files with 352 additions and 123 deletions

View file

@ -3,59 +3,49 @@ use std::collections::BTreeSet;
use rand::{thread_rng, Rng};
use test::{black_box, Bencher};
fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] {
fn random(n: usize) -> BTreeSet<usize> {
let mut rng = thread_rng();
let mut set1 = BTreeSet::new();
let mut set2 = BTreeSet::new();
for _ in 0..n1 {
let i = rng.gen::<usize>();
set1.insert(i);
let mut set = BTreeSet::new();
while set.len() < n {
set.insert(rng.gen());
}
for _ in 0..n2 {
let i = rng.gen::<usize>();
set2.insert(i);
}
[set1, set2]
assert_eq!(set.len(), n);
set
}
fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] {
let mut even = BTreeSet::new();
let mut odd = BTreeSet::new();
for i in 0..n1 {
even.insert(i * 2);
fn neg(n: usize) -> BTreeSet<i32> {
let mut set = BTreeSet::new();
for i in -(n as i32)..=-1 {
set.insert(i);
}
for i in 0..n2 {
odd.insert(i * 2 + 1);
}
[even, odd]
assert_eq!(set.len(), n);
set
}
fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
let mut neg = BTreeSet::new();
let mut pos = BTreeSet::new();
for i in -(n1 as i32)..=-1 {
neg.insert(i);
fn pos(n: usize) -> BTreeSet<i32> {
let mut set = BTreeSet::new();
for i in 1..=(n as i32) {
set.insert(i);
}
for i in 1..=(n2 as i32) {
pos.insert(i);
}
[neg, pos]
assert_eq!(set.len(), n);
set
}
fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
let mut neg = BTreeSet::new();
let mut pos = BTreeSet::new();
for i in -(n1 as i32)..=-1 {
neg.insert(i);
fn stagger(n1: usize, factor: usize) -> [BTreeSet<u32>; 2] {
let n2 = n1 * factor;
let mut sets = [BTreeSet::new(), BTreeSet::new()];
for i in 0..(n1 + n2) {
let b = i % (factor + 1) != 0;
sets[b as usize].insert(i as u32);
}
for i in 1..=(n2 as i32) {
pos.insert(i);
}
[pos, neg]
assert_eq!(sets[0].len(), n1);
assert_eq!(sets[1].len(), n2);
sets
}
macro_rules! set_intersection_bench {
($name: ident, $sets: expr) => {
macro_rules! set_bench {
($name: ident, $set_func: ident, $result_func: ident, $sets: expr) => {
#[bench]
pub fn $name(b: &mut Bencher) {
// setup
@ -63,26 +53,36 @@ macro_rules! set_intersection_bench {
// measure
b.iter(|| {
let x = sets[0].intersection(&sets[1]).count();
let x = sets[0].$set_func(&sets[1]).$result_func();
black_box(x);
})
}
};
}
set_intersection_bench! {intersect_random_100, random(100, 100)}
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)}
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)}
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)}
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)}
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)}
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)}
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)}
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)}
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)}
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)}
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)}
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)}
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)}
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)}
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)}
set_bench! {intersection_100_neg_vs_100_pos, intersection, count, [neg(100), pos(100)]}
set_bench! {intersection_100_neg_vs_10k_pos, intersection, count, [neg(100), pos(10_000)]}
set_bench! {intersection_100_pos_vs_100_neg, intersection, count, [pos(100), neg(100)]}
set_bench! {intersection_100_pos_vs_10k_neg, intersection, count, [pos(100), neg(10_000)]}
set_bench! {intersection_10k_neg_vs_100_pos, intersection, count, [neg(10_000), pos(100)]}
set_bench! {intersection_10k_neg_vs_10k_pos, intersection, count, [neg(10_000), pos(10_000)]}
set_bench! {intersection_10k_pos_vs_100_neg, intersection, count, [pos(10_000), neg(100)]}
set_bench! {intersection_10k_pos_vs_10k_neg, intersection, count, [pos(10_000), neg(10_000)]}
set_bench! {intersection_random_100_vs_100, intersection, count, [random(100), random(100)]}
set_bench! {intersection_random_100_vs_10k, intersection, count, [random(100), random(10_000)]}
set_bench! {intersection_random_10k_vs_100, intersection, count, [random(10_000), random(100)]}
set_bench! {intersection_random_10k_vs_10k, intersection, count, [random(10_000), random(10_000)]}
set_bench! {intersection_staggered_100_vs_100, intersection, count, stagger(100, 1)}
set_bench! {intersection_staggered_10k_vs_10k, intersection, count, stagger(10_000, 1)}
set_bench! {intersection_staggered_100_vs_10k, intersection, count, stagger(100, 100)}
set_bench! {difference_random_100_vs_100, difference, count, [random(100), random(100)]}
set_bench! {difference_random_100_vs_10k, difference, count, [random(100), random(10_000)]}
set_bench! {difference_random_10k_vs_100, difference, count, [random(10_000), random(100)]}
set_bench! {difference_random_10k_vs_10k, difference, count, [random(10_000), random(10_000)]}
set_bench! {difference_staggered_100_vs_100, difference, count, stagger(100, 1)}
set_bench! {difference_staggered_10k_vs_10k, difference, count, stagger(10_000, 1)}
set_bench! {difference_staggered_100_vs_10k, difference, count, stagger(100, 100)}
set_bench! {is_subset_100_vs_100, is_subset, clone, [pos(100), pos(100)]}
set_bench! {is_subset_100_vs_10k, is_subset, clone, [pos(100), pos(10_000)]}
set_bench! {is_subset_10k_vs_100, is_subset, clone, [pos(10_000), pos(100)]}
set_bench! {is_subset_10k_vs_10k, is_subset, clone, [pos(10_000), pos(10_000)]}

View file

@ -3,7 +3,7 @@
use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::{min, max};
use core::cmp::max;
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@ -118,17 +118,36 @@ pub struct Range<'a, T: 'a> {
/// [`difference`]: struct.BTreeSet.html#method.difference
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Difference<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
inner: DifferenceInner<'a, T>,
}
enum DifferenceInner<'a, T: 'a> {
Stitch {
self_iter: Iter<'a, T>,
other_iter: Peekable<Iter<'a, T>>,
},
Search {
self_iter: Iter<'a, T>,
other_set: &'a BTreeSet<T>,
},
}
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Difference")
.field(&self.a)
.field(&self.b)
.finish()
match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => f
.debug_tuple("Difference")
.field(&self_iter)
.field(&other_iter)
.finish(),
DifferenceInner::Search {
self_iter,
other_set: _,
} => f.debug_tuple("Difference").field(&self_iter).finish(),
}
}
}
@ -164,17 +183,36 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
/// [`intersection`]: struct.BTreeSet.html#method.intersection
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Intersection<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
inner: IntersectionInner<'a, T>,
}
enum IntersectionInner<'a, T: 'a> {
Stitch {
small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
other_iter: Iter<'a, T>,
},
Search {
small_iter: Iter<'a, T>,
large_set: &'a BTreeSet<T>,
},
}
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Intersection")
.field(&self.a)
.field(&self.b)
.finish()
match &self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
} => f
.debug_tuple("Intersection")
.field(&small_iter)
.field(&other_iter)
.finish(),
IntersectionInner::Search {
small_iter,
large_set: _,
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
}
}
}
@ -201,6 +239,14 @@ impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
}
}
// This constant is used by functions that compare two sets.
// It estimates the relative size at which searching performs better
// than iterating, based on the benchmarks in
// https://github.com/ssomers/rust_bench_btreeset_intersection;
// It's used to divide rather than multiply sizes, to rule out overflow,
// and it's a power of two to make that division cheap.
const ITER_PERFORMANCE_TIPPING_SIZE_DIFF: usize = 16;
impl<T: Ord> BTreeSet<T> {
/// Makes a new `BTreeSet` with a reasonable choice of B.
///
@ -268,9 +314,24 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn difference<'a>(&'a self, other: &'a BTreeSet<T>) -> Difference<'a, T> {
Difference {
a: self.iter().peekable(),
b: other.iter().peekable(),
if self.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
// Self is bigger than or not much smaller than other set.
// Iterate both sets jointly, spotting matches along the way.
Difference {
inner: DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
},
}
} else {
// Self is much smaller than other set, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
Difference {
inner: DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
},
}
}
}
@ -326,9 +387,29 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
Intersection {
a: self.iter().peekable(),
b: other.iter().peekable(),
let (small, other) = if self.len() <= other.len() {
(self, other)
} else {
(other, self)
};
if small.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
// Small set is not much smaller than other set.
// Iterate both sets jointly, spotting matches along the way.
Intersection {
inner: IntersectionInner::Stitch {
small_iter: small.iter(),
other_iter: other.iter(),
},
}
} else {
// Big difference in number of elements, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
Intersection {
inner: IntersectionInner::Search {
small_iter: small.iter(),
large_set: other,
},
}
}
}
@ -462,28 +543,44 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn is_subset(&self, other: &BTreeSet<T>) -> bool {
// Stolen from TreeMap
let mut x = self.iter();
let mut y = other.iter();
let mut a = x.next();
let mut b = y.next();
while a.is_some() {
if b.is_none() {
return false;
// Same result as self.difference(other).next().is_none()
// but the 3 paths below are faster (in order: hugely, 20%, 5%).
if self.len() > other.len() {
false
} else if self.len() > other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
// Self is not much smaller than other set.
// Stolen from TreeMap
let mut x = self.iter();
let mut y = other.iter();
let mut a = x.next();
let mut b = y.next();
while a.is_some() {
if b.is_none() {
return false;
}
let a1 = a.unwrap();
let b1 = b.unwrap();
match b1.cmp(a1) {
Less => (),
Greater => return false,
Equal => a = x.next(),
}
b = y.next();
}
let a1 = a.unwrap();
let b1 = b.unwrap();
match b1.cmp(a1) {
Less => (),
Greater => return false,
Equal => a = x.next(),
true
} else {
// Big difference in number of elements, or both sets are empty.
// Iterate the small set, searching for matches in the large set.
for next in self {
if !other.contains(next) {
return false;
}
}
b = y.next();
true
}
true
}
/// Returns `true` if the set is a superset of another,
@ -1001,8 +1098,22 @@ fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering
impl<T> Clone for Difference<'_, T> {
fn clone(&self) -> Self {
Difference {
a: self.a.clone(),
b: self.b.clone(),
inner: match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => DifferenceInner::Stitch {
self_iter: self_iter.clone(),
other_iter: other_iter.clone(),
},
DifferenceInner::Search {
self_iter,
other_set,
} => DifferenceInner::Search {
self_iter: self_iter.clone(),
other_set,
},
},
}
}
}
@ -1011,24 +1122,52 @@ impl<'a, T: Ord> Iterator for Difference<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) {
Less => return self.a.next(),
Equal => {
self.a.next();
self.b.next();
}
Greater => {
self.b.next();
match &mut self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => {
let mut self_next = self_iter.next()?;
loop {
match other_iter
.peek()
.map_or(Less, |other_next| Ord::cmp(self_next, other_next))
{
Less => return Some(self_next),
Equal => {
self_next = self_iter.next()?;
other_iter.next();
}
Greater => {
other_iter.next();
}
}
}
}
DifferenceInner::Search {
self_iter,
other_set,
} => loop {
let self_next = self_iter.next()?;
if !other_set.contains(&self_next) {
return Some(self_next);
}
},
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let a_len = self.a.len();
let b_len = self.b.len();
(a_len.saturating_sub(b_len), Some(a_len))
let (self_len, other_len) = match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter
} => (self_iter.len(), other_iter.len()),
DifferenceInner::Search {
self_iter,
other_set
} => (self_iter.len(), other_set.len()),
};
(self_len.saturating_sub(other_len), Some(self_len))
}
}
@ -1073,8 +1212,22 @@ impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
impl<T> Clone for Intersection<'_, T> {
fn clone(&self) -> Self {
Intersection {
a: self.a.clone(),
b: self.b.clone(),
inner: match &self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
} => IntersectionInner::Stitch {
small_iter: small_iter.clone(),
other_iter: other_iter.clone(),
},
IntersectionInner::Search {
small_iter,
large_set,
} => IntersectionInner::Search {
small_iter: small_iter.clone(),
large_set,
},
},
}
}
}
@ -1083,24 +1236,39 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
loop {
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
Less => {
self.a.next();
}
Equal => {
self.b.next();
return self.a.next();
}
Greater => {
self.b.next();
match &mut self.inner {
IntersectionInner::Stitch {
small_iter,
other_iter,
} => {
let mut small_next = small_iter.next()?;
let mut other_next = other_iter.next()?;
loop {
match Ord::cmp(small_next, other_next) {
Less => small_next = small_iter.next()?,
Greater => other_next = other_iter.next()?,
Equal => return Some(small_next),
}
}
}
IntersectionInner::Search {
small_iter,
large_set,
} => loop {
let small_next = small_iter.next()?;
if large_set.contains(&small_next) {
return Some(small_next);
}
},
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(min(self.a.len(), self.b.len())))
let min_len = match &self.inner {
IntersectionInner::Stitch { small_iter, .. } => small_iter.len(),
IntersectionInner::Search { small_iter, .. } => small_iter.len(),
};
(0, Some(min_len))
}
}

View file

@ -69,6 +69,20 @@ fn test_intersection() {
check_intersection(&[11, 1, 3, 77, 103, 5, -5],
&[2, 11, 77, -9, -42, 5, 3],
&[3, 5, 11, 77]);
let large = (0..1000).collect::<Vec<_>>();
check_intersection(&[], &large, &[]);
check_intersection(&large, &[], &[]);
check_intersection(&[-1], &large, &[]);
check_intersection(&large, &[-1], &[]);
check_intersection(&[0], &large, &[0]);
check_intersection(&large, &[0], &[0]);
check_intersection(&[999], &large, &[999]);
check_intersection(&large, &[999], &[999]);
check_intersection(&[1000], &large, &[]);
check_intersection(&large, &[1000], &[]);
check_intersection(&[11, 5000, 1, 3, 77, 8924, 103],
&large,
&[1, 3, 11, 77, 103]);
}
#[test]
@ -84,6 +98,18 @@ fn test_difference() {
check_difference(&[-5, 11, 22, 33, 40, 42],
&[-12, -5, 14, 23, 34, 38, 39, 50],
&[11, 22, 33, 40, 42]);
let large = (0..1000).collect::<Vec<_>>();
check_difference(&[], &large, &[]);
check_difference(&[-1], &large, &[-1]);
check_difference(&[0], &large, &[]);
check_difference(&[999], &large, &[]);
check_difference(&[1000], &large, &[1000]);
check_difference(&[11, 5000, 1, 3, 77, 8924, 103],
&large,
&[5000, 8924]);
check_difference(&large, &[], &large);
check_difference(&large, &[-1], &large);
check_difference(&large, &[1000], &large);
}
#[test]
@ -114,6 +140,41 @@ fn test_union() {
&[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]);
}
#[test]
// Only tests the simple function definition with respect to intersection
fn test_is_disjoint() {
let one = [1].into_iter().collect::<BTreeSet<_>>();
let two = [2].into_iter().collect::<BTreeSet<_>>();
assert!(one.is_disjoint(&two));
}
#[test]
// Also tests the trivial function definition of is_superset
fn test_is_subset() {
fn is_subset(a: &[i32], b: &[i32]) -> bool {
let set_a = a.iter().collect::<BTreeSet<_>>();
let set_b = b.iter().collect::<BTreeSet<_>>();
set_a.is_subset(&set_b)
}
assert_eq!(is_subset(&[], &[]), true);
assert_eq!(is_subset(&[], &[1, 2]), true);
assert_eq!(is_subset(&[0], &[1, 2]), false);
assert_eq!(is_subset(&[1], &[1, 2]), true);
assert_eq!(is_subset(&[2], &[1, 2]), true);
assert_eq!(is_subset(&[3], &[1, 2]), false);
assert_eq!(is_subset(&[1, 2], &[1]), false);
assert_eq!(is_subset(&[1, 2], &[1, 2]), true);
assert_eq!(is_subset(&[1, 2], &[2, 3]), false);
let large = (0..1000).collect::<Vec<_>>();
assert_eq!(is_subset(&[], &large), true);
assert_eq!(is_subset(&large, &[]), false);
assert_eq!(is_subset(&[-1], &large), false);
assert_eq!(is_subset(&[0], &large), true);
assert_eq!(is_subset(&[1, 2], &large), true);
assert_eq!(is_subset(&[999, 1000], &large), false);
}
#[test]
fn test_zip() {
let mut x = BTreeSet::new();