Efficient set intersection of a collection of sets in C++
I have a collection of std::set. I want to find the intersection of all the sets in this collection, in the fastest manner. The number of sets in the collection is typically very small (~5-10), and the number of elements in each set is usually less than 1000, but can occasionally go up to around 10000. I need to do these intersections tens of thousands of times, as fast as possible. I tried to benchmark a few methods as follows:

  1. In-place intersection in a std::set object which initially copies the first set. Then for each subsequent set, it iterates over all elements of itself and of the ith set of the collection, and removes items from itself as needed.
  2. Using std::set_intersection into a temporary std::set, swapping its contents into the current set, then again intersecting the current set with the next set into the temporary set, and so on.
  3. Manually iterate over all the elements of all sets like in 1), but using a vector as the destination container instead of std::set.
  4. Same as in 3, but using a std::list instead of a vector, suspecting that a list will provide faster deletions from the middle.
  5. Using hash sets (std::unordered_set) and checking for all items in all sets.

As it turned out, using a vector is marginally faster when the number of elements in each set is small, and a list is marginally faster for larger sets. In-place intersection using a set is substantially slower than both, followed by set_intersection and hash sets. Is there a faster algorithm/data structure/trick to achieve this? I can post code snippets if required. Thanks!
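
For reference, a minimal sketch of roughly what approach 3 looks like. The question does not show its code, so the element type (int), ascending order, and compacting in place rather than literally erasing from the middle are assumptions here:

#include <cstddef>
#include <set>
#include <vector>

// Seed a std::vector from the first set, then co-iterate it with every other
// set, keeping only the elements found in all of them (compacting in place).
std::vector<int> intersect_vec(const std::vector<std::set<int>>& sets) {
    std::vector<int> result;
    if (sets.empty()) return result;
    result.assign(sets.front().begin(), sets.front().end());

    for (std::size_t i = 1; i < sets.size() && !result.empty(); ++i) {
        std::size_t out = 0;                    // write position in result
        auto it = sets[i].begin();
        for (std::size_t in = 0; in < result.size(); ++in) {
            while (it != sets[i].end() && *it < result[in]) ++it;
            if (it == sets[i].end()) break;     // nothing further can match
            if (*it == result[in]) result[out++] = result[in];
        }
        result.resize(out);                     // drop the unmatched tail
    }
    return result;
}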

Orthoclase answered 13/10, 2012 at 18:57 Comment(8)
The question really depends on whether or not you expect to find many common elements, as this alters the "best" structure that one can come up with. For example, a 6th method could be to simply use an std::unordered_map and count the number of occurrences of each element. It's O(N) in the total number of elements. Then, you just pick the elements that have a total equal to the number of sets, O(M) in the number of distinct elements (a sketch of this idea appears after these comments). No idea how well it would perform.Arneson
@MatthieuM. I see. I will give this a try, though I suspect it won't be faster than a std::list due to hashing and other overheads. Thanks!Orthoclase
@MatthieuM. This method will give the resulting set in unsorted order. Luckily, I have two use cases, one which requires the result in sorted order, and one which does not. If this method is reasonably fast, I can use it at least for the case where the intersection does not need to be sorted.Orthoclase
@MatthieuM. I tried this approach, and for my data, this was only slightly faster than my approach 5 (using unordered_set).Orthoclase
You could try this idea. Worst case linear (can't avoid that, if the sets have mostly the same elements), but if the intersection is small, it can be much faster.Clausewitz
@DanielFischer Thank you! Due to Dietmar's answer below, I had also thought about using a binary search when doing searches in arrays. But the worst-case slowdown was a worry. You propose a very nice heuristic/estimation to make this a hybrid approach. Indeed, this is only marginally slower than the vector approach (pt 3 above) due to minor extra computations, but clearly the fastest among all if the sizes of the subsequent sets are sufficiently larger than the current one! Very nice idea!Orthoclase
@DanielFischer I would have accepted this if it was an answer.Orthoclase
could we have a look at the source of that test?Gormand
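
For reference, a rough sketch of the counting idea from the first comment above (a sketch only, assuming int elements; note that the result comes out unsorted):

#include <cstddef>
#include <set>
#include <unordered_map>
#include <vector>

// Count how many sets each element appears in; since keys are unique within a
// set, an element belongs to the intersection iff its count equals the number
// of sets. O(N) to count, O(M) to collect, but the result is unsorted.
std::vector<int> intersect_by_counting(const std::vector<std::set<int>>& sets) {
    std::vector<int> result;
    if (sets.empty()) return result;

    std::unordered_map<int, std::size_t> counts;
    for (const auto& s : sets)
        for (int x : s)
            ++counts[x];

    for (const auto& kv : counts)
        if (kv.second == sets.size())
            result.push_back(kv.first);

    return result;
}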

You might want to try a generalization of std::set_intersection(): the algorithm keeps one iterator into each set (a rough sketch follows the steps below):

  1. If any iterator has reached the end() of its corresponding set, you are done. Otherwise, all iterators can be assumed to be valid.
  2. Take the first iterator's value as the next candidate value x.
  3. Move through the list of iterators and std::find_if() the first element at least as big as x.
  4. If the value found is bigger than x, make it the new candidate value and search again through the sequence of iterators.
  5. If all iterators are at value x, you have found an element of the intersection: record it, increment all iterators, and start over.
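
A rough sketch of the steps above (the answer gives no code, so the exact bookkeeping is an interpretation, assuming a vector of sets of int):

#include <cstddef>
#include <set>
#include <vector>

// Keep one iterator per set and advance them all towards a common candidate x.
std::vector<int> intersect_multiway(const std::vector<std::set<int>>& sets) {
    std::vector<int> result;
    if (sets.empty()) return result;

    std::vector<std::set<int>::const_iterator> its;
    for (const auto& s : sets) its.push_back(s.begin());

    while (true) {
        // Step 1: if any iterator has reached the end of its set, we are done.
        for (std::size_t i = 0; i != sets.size(); ++i)
            if (its[i] == sets[i].end()) return result;

        // Step 2: take the first iterator's value as the candidate x.
        int x = *its[0];

        // Steps 3/4: advance every iterator to the first element >= x; whenever
        // a bigger value shows up it becomes the new candidate and we rescan.
        bool changed = true;
        while (changed) {
            changed = false;
            for (std::size_t i = 0; i != sets.size(); ++i) {
                while (its[i] != sets[i].end() && *its[i] < x) ++its[i];
                if (its[i] == sets[i].end()) return result;
                if (*its[i] > x) { x = *its[i]; changed = true; }
            }
        }

        // Step 5: all iterators agree on x, so it is in the intersection.
        result.push_back(x);
        for (auto& it : its) ++it;
    }
}
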
Furfuraceous answered 13/10, 2012 at 19:16 Comment(24)
I would not recommend std::find_if when one is working with std::set; after all, std::set features its own lower_bound and upper_bound, which are typically faster.Arneson
@MatthieuM. not in this case, find_if will on average never need to advance more than two elements and is thus O(1), while lower/upper_bound is O(log n).Rakel
@MatthieuM.: Obviously, it depends on the interface of the algorithm and I would operate on a sequence of pairs of input iterators: std::set_intersection() does as well. Interestingly, I think the complexity of your suggested approach is O((n log n) * m), where n is the maximum size of the sets and m is the number of sets. My algorithm has complexity O(n * m). I think my approach wins.Epencephalon
@Rakel Thanks! I did not understand: why will find_if on average never need to advance more than two elements?Orthoclase
@leftaroundabout: like Paresh I wonder where the 2 elements come from (I might be missing something obvious). It seems to me it would depend on how the data is distributed, would it not? For example, suppose that I have a set of 100 elements and another of 1000 elements covering the same range. Then on average I will need to skip about 10 elements from the large set at each step.Arneson
@Paresh, Matthieu M.: actually, that was a heuristic argument; thinking it over, I'm not that convinced anymore. It does hold if all the sets hold roughly the same number of values from the same random distribution, but if one of the sets happens to have quite a lot of values in between two values of the first set, it doesn't work this way.Rakel
@MatthieuM., Paresh: still, I think it should also work on average for some class of distributions for each set separately, but I'm not sure if I can prove it.Rakel
@DietmarKühl: I think it really depends on the data distribution. With two sets of 100 and 10000 elements each, my approach involves 100 lookups in the 10000 elements set (13 comparisons each in average), thus about 1300 comparisons. Your approach involves 200000 comparisons.Arneson
@leftaroundabout: left with the same issue, I'll need to think it over again come morning with a clear head :/Arneson
@DietmarKühl: I must admit I really like your solution, especially because it's so simple. However it bugs me, somehow, that the "best" algorithm would have linear complexity when we are talking about sorted data. I mean, assuming a good hash function, it's the same complexity you could expect from using a hash-set/hash-map based mechanism; even though it's best in terms of memory allocation obviously, since you are not allocating anything.Arneson
The most basic algorithm (going through the first set element by element and searching for the elements in all the others) has complexity O (n * (log n)^m). You probably can build an efficient algorithm using lower_bound.Dade
@DietmarKühl If I understood correctly, in step 4, even if for some iterator, the value found is greater than x, we can keep looking and keep track of the greatest value found till the end, and make that the new x. Am I right about this?Orthoclase
@Dade Ah! lower_bound uses binary search. But the exponent looks scary! :) I guess when I implement this, I'll try out with both find_if and lower_bound.Orthoclase
@JohnB: Yes, I don't recommend the naive approach. However had we access to the internals of the set iterator, we could skip a number of comparisons -> I am not convinced that the visitation mechanism used here (left sub-tree, element, right sub-tree) is the best that we can achieve.Arneson
Sorry, complexity of the most basic algorithm should have been O (n * (m-1) * log n). I doubt, however, that you will get any better than that, because you probably cannot avoid looking at each element in at least one set once. Consider, e.g., taking the intersection of {1,3,5}, {1,3,5}, and {2,4,6}. If you have bad luck, you will detect the "mismatch" only at having considered all other sets, in each step of your algorithm. The most efficient optimization might be always starting with the smallest set (i.e. the one having the least number of elements [left]).Dade
@Paresh: yes, you can use the biggest next value as the next candidate: no other value can possibly be in the intersection. Also, if the sizes of the sets are available, it may be reasonable to order them by size, looking for candidates starting at the smallest one. This way, taking bigger strides on the larger containers using logarithmic search may be an advantage.Epencephalon
@DietmarKühl Whoa! For my data, this turned out to be the slowest so far - anywhere between 5x and 50x slower. I am reasonably sure that the implementation was decent (or at least as decent as for the other approaches). I guess the extra log n factor is really killing the performance. The next element is more than likely very near the current position, so binary search ends up doing close to the full log n comparisons, whereas find_if would be faster. I am trying to modify it to use find_if instead of lower_bound, but am stuck at making a comparator for checking against the current best.Orthoclase
There must be something wrong if this is actually slow! It should be faster than using multiple std::set_intersection(). To create a suitable predicate, you should be able to use std::bind1st(std::less_equal<T>(), x).Epencephalon
@DietmarKühl Instead of using find_if, I manually incremented the iterator till the appropriate value (I guess find_if does exactly this). This way, the time was brought down significantly. It is now faster than repeated set_intersection, but still slower than the two fastest (3 and 4 in the question). All I did was change iterators[i] = lower_bound(iterators[i], sets[i].end(), currentValue, comparator); to while (iterators[i] != sets[i].end() && *(*iterators[i]) > *currentValue) ++iterators[i]; which is basically replacing lower_bound by find_ifOrthoclase
Please note the set is in descending order, hence the > sign instead of the expected < sign. Also, the set consists of pointers, hence the double de-referencing.Orthoclase
@DietmarKühl I suspect, at the cost of code complexity, this can be made faster in the following way: if we find an x greater than the current best, we invalidate all previous checks, and for the subsequent iterations, search for this new x. At the end of the list, we start from the beginning again, as the earlier iterators have been invalidated by the current x being greater than what was checked against them. This will result in big jumps (it is the small jumps that are causing problems for the log n lower_bound). The other improvement could be the one you suggested: order the sets (haven't done that in this).Orthoclase
@DietmarKühl I did some more digging, and found that the massive increase in time for the lower_bound approach was because I was using the global/algorithms lower_bound. When I switched to std::set::lower_bound, the time dropped drastically to be comparable to repeated set_intersection. However, it was still not as fast as the linear increment similar to find_if described above. I suppose the global lower_bound can only step the set's (non-random-access) iterators one element at a time, while the member version walks the tree directly.Orthoclase
@DietmarKühl With some more tweaking, the lower_bound (binary search) based approach is faster than find_if (linear) when the intersection is small, but is slower when the intersection is large. I guess this is to be expected. Overall, these are close to, but slower than, approaches 3 and 4 in the question, and much faster than the other approaches listed (both ways of advancing the iterators are sketched after these comments).Orthoclase
Accepting this since it has scope for further optimization, and can be used for early termination if, for some use case, only the first few elements of the intersection are needed.Orthoclase
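
For context, the two ways of advancing an iterator to the next candidate compared in the comments above look roughly like this (a simplified sketch assuming a plain ascending std::set<int>, unlike the descending set of pointers the OP actually uses):

#include <set>

// (a) Logarithmic: the member lower_bound searches the whole tree, ignoring
//     the current position. Good when the next candidate is far away.
void advance_logarithmic(const std::set<int>& s,
                         std::set<int>::const_iterator& it, int x) {
    it = s.lower_bound(x);
}

// (b) Linear: step forward from the current position. Cheap when the next
//     candidate is only a few elements away, as it tends to be for large overlaps.
void advance_linear(const std::set<int>& s,
                    std::set<int>::const_iterator& it, int x) {
    while (it != s.end() && *it < x) ++it;
}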

Night is a good adviser and I think I may have an idea ;)

  • Memory is much slower than the CPU these days; if all the data fits in the L1 cache it is no big deal, but it easily spills over to L2 or L3: 5 sets of 1000 elements is already 5000 elements, meaning 5000 nodes, and a set node contains at least 3 pointers + the object (i.e., at least 16 bytes on a 32-bit machine and 32 bytes on a 64-bit machine) => that's at least 80 KB of memory, and recent CPUs only have 32 KB of L1D cache, so we are already spilling into L2
  • The previous fact is compounded by the problem that set nodes are probably scattered around memory, not tightly packed together, meaning that part of each cache line is filled with completely unrelated stuff. This could be alleviated by providing an allocator that keeps nodes close to each other.
  • And this is further compounded by the fact that CPUs are much better at sequential reads (where they can prefetch memory before you need it, so you don't wait for it) than at random reads (and a tree structure unfortunately leads to quite random reads)

This is why, where speed matters, a vector (or perhaps a deque) is such a great structure: it plays very well with memory. As such, I would definitely recommend using vector as our intermediary structure, although care needs to be taken to only ever insert/delete at an extremity to avoid relocation.

So I thought about a rather simple approach:

#include <cassert>

#include <algorithm>
#include <set>
#include <vector>

// Do not call this method if you have a single set...
// And the pointers better not be null either!
std::vector<int> intersect(std::vector< std::set<int> const* > const& sets) {
    for (auto s: sets) { assert(s && "I said no null pointer"); }

    std::vector<int> result; // only return this one, for NRVO to kick in

    // 0. Check obvious cases
    if (sets.empty()) { return result; }

    if (sets.size() == 1) {
        result.assign(sets.front()->begin(), sets.front()->end());
        return result;
    }


    // 1. Merge first two sets in the result
    std::set_intersection(sets[0]->begin(), sets[0]->end(),
                          sets[1]->begin(), sets[1]->end(),
                          std::back_inserter(result));

    if (sets.size() == 2) { return result; }


    // 2. Merge consecutive sets with result into buffer, then swap them around
    //    so that the "result" is always in result at the end of the loop.

    std::vector<int> buffer; // outside the loop so that we reuse its memory

    for (size_t i = 2; i < sets.size(); ++i) {
        buffer.clear();

        std::set_intersection(result.begin(), result.end(),
                              sets[i]->begin(), sets[i]->end(),
                              std::back_inserter(buffer));

        swap(result, buffer);
    }

    return result;
}

It seems correct; I cannot guarantee its speed though, obviously.

Arneson answered 14/10, 2012 at 12:12 Comment(2)
Thanks! The compactness of memory was the reason I tried option 3 in the original question: using a vector as an intermediate container, just as you have done. The difference being that you used set_intersection, which requires two vectors, while I kept one vector, with the disadvantage that I had to erase from the middle. Even though your approach should ideally have been faster, I guess complex factors like contiguous memory, caching (1 array vs 2), etc. are making it slower than the options 3 and 4 that I tried above. Of course, mileage may vary based on the data.Orthoclase
+1 for thinking in terms of memory and caching, and giving a nice explanation! As a side note, I am considering using vectors instead of std::set, and inserting in sorted order into vectors if that is comparable. Compactness may make it reasonably fast, and intersections would definitely be faster.Orthoclase
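
As a follow-up to that last comment, a minimal sketch of inserting into a std::vector while keeping it sorted and duplicate-free (assuming int elements; not code from the thread):

#include <algorithm>
#include <vector>

// Insert x into an already-sorted vector, keeping it sorted and set-like.
void insert_sorted_unique(std::vector<int>& v, int x) {
    auto pos = std::lower_bound(v.begin(), v.end(), x);
    if (pos == v.end() || *pos != x)
        v.insert(pos, x);
}

With all inputs stored as sorted vectors, std::set_intersection then runs over contiguous memory, which is what makes this attractive despite the O(n) cost of each insertion.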
