Calling std::nth_element() function extremely frequently
I did not find this specific topic anywhere...

I am calling the nth_element() algorithm about 400,000 times per second on different data in a std::vector of 23 integers (more precisely, "unsigned short" values).

I want to improve computation speed, and this particular call takes a significant portion of CPU time. I noted, as with std::sort(), that the nth_element function is visible in the profiler even at the highest optimisation level and with NDEBUG defined (Linux Clang compiler), so the comparison is inlined but not the function call itself. Well, more precisely: not nth_element() itself but std::__introselect() is visible.
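For reference, the hot call boils down to something like this (a minimal sketch; the wrapper name is mine):

```cpp
#include <algorithm>
#include <array>
#include <cassert>

// Minimal sketch of the hot call: the median of 23 unsigned shorts,
// i.e. element 11 (0-based) of the sorted order.
unsigned short median23(std::array<unsigned short, 23>& v)
{
    std::nth_element(v.begin(), v.begin() + 11, v.end());
    return v[11];
}
```

Note that std::nth_element also partitions the array around the returned element, which is more work than strictly needed if only the median value matters.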

Since the data size is small, I experimented with a quadratic sorting function, PIKSORT, which is often quicker than calling std::sort when the data has fewer than about 20 elements, probably because it can be inlined.

template <class CONTAINER>
inline void piksort(CONTAINER& arr)  // indeed this is "insertion sort"
{
    typename CONTAINER::value_type a;

    const int n = (int)arr.size();
    for (int j = 1; j<n; ++j) {
        a = arr[j];
        int i = j;
        while (i > 0 && a < arr[i - 1]) {
            arr[i] = arr[i - 1];
            i--;
        }
        arr[i] = a;
    }
}

However, this was slower than using nth_element in this case.

Also, using a statistical method is not appropriate; see Something faster than std::nth_element

Finally, since the values are in the range from 0 to about 20000, a histogram method does not look appropriate.

My question: does anyone know a simple solution to this? I think I am probably not the only one to have to call std::sort or nth_element very frequently.

Thrombus answered 23/10, 2015 at 17:11 Comment(21)
Is the order variable, or would it be possible to sort the vector once?Trojan
The typical name for your algorithm is "insertion sort", and your first loop's bounds are incorrect (what about arr[0]?)Sandrocottus
Are all the vectors completely different? Or are they obtained one from another by minor modifications (e.g. an element added, an element removed)?Singularity
You have mentioned that size of vector is usually 23. What can be said about the index (in sorted order) of the element you are trying to find?Singularity
1. We simply search for the median value, i.e. the value of the middle element value[11] in the sorted vector. 2. Thank you for the name. However, the iteration j=0 seems useless to me; the code seems to work fine in its current state. 3. @stgatilov: yes, it is always new data, in no known pre-order.Thrombus
If piksort is faster it's due to cache friendliness, not inlining. And sort and nth_element implementations are quite different. So I would suspect that nth_element is more cache friendly. You could try implementing actual piko_nth_element instead of full fledged sort. Also this piksort looks like bubble sort, or am I missing something.Medamedal
Can't you just sort the data once using insertion sort and answer each query in constant time just by reading the element at given index?Ossie
Is this a median filter for an image by chance?Rumble
The data are new in each call. And it is image data, yes. It has to do with edge features: we find possible edges, determine orientation, and calculate the median for each a Bresenham line outside and inside. Currently I cannot see how I can precompute it.Thrombus
Looking at std::nth_element, it does not only find the k-th element in sorted order, it also partitions the input array into less-than and greater-than elements. Do you want to get this partition? Would it be OK for you to only obtain the median value without changing the input array (i.e. a non-destructive version)?Singularity
Have a look at this answer from #811157. It claims to be much faster than std::nth_element. Blog-post here: disnetwork.info/the-blog/median-value-selection-algorithmPentad
@PeterCordes Thank you for this hint, I will try this.Thrombus
@PeterCordes I adapted the algorithm of Christophe Meessen in the blog for 23 unsigned shorts instead of 27 floats, but it was not faster than nth_element on my Intel i5 CPU.Thrombus
@Singularity Yes I only need the median value, not the partitioning. Currently the quickest method is the network sort, the HeapMedian3 method was not faster in my test (although I would have liked a factor of 10 improvement). From intuition I would say, that the network sort still does more work than necessary.Thrombus
@karsten: you can leave out any compare&swaps that can't affect the median element of the final output. This may not help much if you're using SIMD for the sorting network. (x86 SSE2 can do 8 packed 16bit integers per vector. AVX2 can do 16, but that might not be helpful for sorting 23 elements.)Pentad
@karsten: What about my vectorized O(N^2) solution? It does much more work than necessary, but much faster =) Also, it is not tied to n=23 case.Singularity
@Singularity Yes, many thanks, I will try it out tomorrow.Thrombus
Speaking of the linked question, I advise looking into the accepted answer. It references a great paper which accelerates median filter for image processing. It solves a more specific problem than the one you have asked, but it may help you more than our answers.Singularity
interesting paper, thanksThrombus
How are the integers distributed? Uniformly?Unwilling
@Unwilling yes, more or less.Thrombus

You mentioned that the size of the array was always known to be 23. Moreover, the type used is unsigned short. In this case, you might try to use a sorting network of size 23; since your type is unsigned short, sorting the whole array with a sorting network might be even faster than partially sorting it with std::nth_element. Here is a very straightforward C++14 implementation of a sorting network of size 23 with 118 compare-exchange units, as described by Using Symmetry and Evolutionary Search to Minimize Sorting Networks:

template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}

The swap_if utility function compares two parameters x and y with the predicate compare and swaps them if compare(y, x) holds. My example uses a generic swap_if function, but you can use an optimized version if you know that you will be comparing unsigned short values with operator< anyway (you might not need such a function if your compiler recognizes and optimizes the compare-exchange, but unfortunately not all compilers do that; I am using g++ 5.2 with -O3 and I still need the following function for performance):

void swap_if(unsigned short& x, unsigned short& y)
{
    // branchless compare-exchange: x becomes the minimum, y the maximum
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp; // if tmp == dx, y is unchanged; otherwise y becomes dx
}
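The generic swap_if used by the sorting network above is not reproduced in the answer; a minimal sketch (my own, hypothetical) matching the three-argument call sites could look like this:

```cpp
#include <utility>

// Hypothetical generic compare-exchange: swap x and y if compare(y, x)
// holds, so that after the call x precedes y under compare.
template<typename T, typename Compare>
void swap_if(T& x, T& y, Compare compare)
{
    if (compare(y, x)) {
        std::swap(x, y);
    }
}
```

The branch in this version is exactly what the specialized min/xor overload above avoids.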

Now, just to make sure that it is indeed faster, I decided to time std::nth_element when required to partially sort only the first 10 elements vs. sorting the whole 23 elements with the sorting network (1000000 times with different shuffled arrays). Here is what I get:

std::nth_element    1158ms
network_sort23      487ms

That said, my computer has been running for a bit of time and is a bit slow, but the difference in performance is neat. I believe that this difference will remain the same when I restart my computer. I may try it later and let you know.

Regarding how these times were generated, I used a modified version of this benchmark from my cpp-sort library. The original sorting network and swap_if functions come from there as well, so you can be sure that they have been tested more than once :)

EDIT: here are the results now that I have restarted my computer. The network_sort23 version is still two times faster than std::nth_element:

std::nth_element    369ms
network_sort23      154ms

EDIT²: if all you need is the median, you can trivially delete the compare-exchange units that are not needed to compute the final value that will be at the 11th position. The resulting median-finding network of size 23 below is derived from a different size-23 sorting network than the previous one, and it yields slightly better results:

swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);

There are probably smarter ways to generate median-finding networks, but I don't think that extensive research has been done on the subject. Therefore, it's probably the best method you can use as of now. The result isn't awesome but it still uses 104 compare-exchange units instead of 118.

Kentonkentucky answered 24/10, 2015 at 14:38 Comment(11)
@karsten: if you're targeting an architecture with vector instructions, sorting networks can often be sped up with vectors. e.g. packed_min(xvec, yvec) and packed_max(xvec, yvec) does multiple swap_ifs in parallel. x86's SSE2 has an instruction for min/max of signed words (16bit vector elements). SSE4.1 is needed for unsigned words, or anything other than unsigned bytes / signed words. I assume ARM NEON has something similar. Shuffling is the bottleneck in vector sorting networks, though, to line up elements for comparisons.Pentad
Hello Morwenn, this is an interesting idea.Thrombus
This is the kind of question & answer that makes dredging through the pile of "do my homework" or "debug my program" posts that SO attracts today worthwhile.Breslau
Hello Morwenn, this is an interesting idea! I tried it out (however not using C++11), and it was significantly quicker. Your code for swap_if was much quicker than my first simple implementation static inline void swap_if(unsigned short& x, unsigned short& y) { if (y < x) std::swap(x,y); } although it is not obvious to me why. Interestingly, I first tried out to measure speed using the tool "callgrind" from the valgrind suite (I was quite put down, since it indicated it would take longer), but when I measured the plain execution times it was much quicker! Thank you very much.Thrombus
@MichaelBurr : Nice to hear.Thrombus
@Thrombus Apparently it's a missed optimization opportunity from GCC. I have submitted a bug report but I have to admit that I didn't invent this swap_if implementation, I took it from this answer and timed it to make sure it was worth it.Kentonkentucky
@Thrombus I updated my anwer with a slightly more problem-specific network, even if it's no more than a small improvement. I simply removed compare-exchange units from a sorting network so that nothing useless is done once the median first[11u] is in its correct place.Kentonkentucky
@Kentonkentucky Thank you, I will test it today and post the result.Thrombus
yes, it is slightly faster. By the way: why is your swap_if faster than { if (y<x) swap(x,y); } ? Perhaps because no branch is involved?Thrombus
@Thrombus I've already told you all I know about swap_if four comments ago :pKentonkentucky
@Kentonkentucky sorry didn't see your previous comment on swap_ifThrombus

General idea

Looking at the source code of std::nth_element in MSVC 2013, it seems that cases of N <= 32 are solved by insertion sort. It means that the STL implementors realized that doing randomized partitions would be slower, despite better asymptotics, for those sizes.

One of the ways to improve performance is to optimize sorting algorithm. @Morwenn's answer shows how to sort 23 elements with a sorting network, which is known to be one of the fastest ways to sort small constant-sized arrays. I'll investigate the other way, which is to calculate median without sorting algorithm. In fact, I won't permute the input array at all.

Since we are talking about small arrays, we need to implement some O(N^2) algorithm in the simplest way possible. Ideally, it should have no branches at all, or only well-predictable branches. Also, the simple structure of the algorithm could allow us to vectorize it, further improving its performance.

Algorithm

I have decided to follow the counting method, which was used here to accelerate small linear search. First of all, suppose that all the elements are distinct. Choose any element of the array: the number of elements less than it defines its position in the sorted array. We can iterate over all elements, and for each of them calculate the number of elements less than it. If the sorted index has the desired value, we can stop the algorithm.

Unfortunately, there may be equal elements in the general case. We'll have to make our algorithm somewhat slower and more complex to handle them. Instead of calculating the unique sorted index of an element, we can calculate the interval of possible sorted indices for it. For any element, it is enough to count the number of elements less than it (L) and the number of elements equal to it (E); then its sorted index fits in the range [L, L+E). If this interval contains the desired sorted index (i.e. N/2), then we can stop the algorithm and return the considered element.

for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: https://mcmap.net/q/86123/-fastest-way-to-determine-if-an-integer-is-between-two-integers-inclusive-with-known-sets-of-values
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}
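For completeness, here is the excerpt above wrapped into a self-contained function (the signature and surrounding declarations are my assumptions, not from the original):

```cpp
#include <cstddef>

// Scalar counting selection: return the element whose 0-based sorted
// index interval [cntLess, cntLess + cntEq) contains idx.
// For the median of n elements, pass idx = n / 2.
unsigned short count_select(const unsigned short* arr, std::size_t n,
                            unsigned int idx)
{
    for (std::size_t i = 0; i < n; i++) {
        auto x = arr[i];
        // count the number of "less" and "equal" elements
        int cntLess = 0, cntEq = 0;
        for (std::size_t j = 0; j < n; j++) {
            cntLess += arr[j] < x;
            cntEq += arr[j] == x;
        }
        // fast range check: unsigned wraparound makes this false when
        // idx < cntLess, so it tests idx in [cntLess, cntLess + cntEq)
        if ((unsigned int)(idx - cntLess) < (unsigned int)cntEq)
            return x;
    }
    return 0; // unreachable for a valid idx
}
```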

Vectorization

The constructed algorithm has only one branch, which is rather predictable: it fails in all cases except the single one in which we stop the algorithm. The algorithm is easy to vectorize using 8 elements per SSE register. Since we'll have to access some elements after the last one, I'll assume that the input array is padded with max = 2^15-1 values up to 24 or 32 elements.

The first way is to vectorize inner loop by j. In this case inner loop would be executed only 3 times, but two 8-wide reductions must be done after it is finished. They eat more time than the inner loop itself. As a result, such a vectorization is not very efficient.

The second way is to vectorize outer loop by i. In this case we process 8 elements x = arr[i] at once. For each pack, we compare it with each element arr[j] in inner loop. After the inner loop we perform vectorized range check for the whole pack of 8 elements. If any of them succeeds, we determine exact number with simple scalar code (it eats little time anyway).

__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        __m128i vAll = _mm_set1_epi16(arr[j]);
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it 
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}

Here we see _mm_set1_epi16 intrinsic in the innermost loop. GCC seems to have some performance issues with it. Anyway, it is eating time on each innermost iteration, which can be reduced if we process 8 elements at once in the innermost loop too. In such case we can do one vectorized load and 14 unpack instructions to obtain vAll for eight elements. Also, we'll have to write compare-and-count code for eight elements in loop body, so it acts as 8x unrolling too. The resulting code is the fastest one, a link to it can be found below.

Comparison

I have benchmarked various solutions on an Ivy Bridge 3.4 GHz processor. Below you can see the total computation time for 2^23 ~= 8M calls in seconds (the first number). The second number is a checksum of the results.

Results on MSVC 2013 x64 (/O2):

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j

Results on MinGW GCC 4.8.3 x64 (-O3 -msse4):

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)

As you see, the proposed vectorized algorithm for 23 16-bit elements is a bit faster than the sorting-based approach (BTW, on an older CPU I see only a 5% time difference). If you can guarantee that all elements are distinct, you can simplify the algorithm, making it even faster.

The full code of all algorithms is available here, including all the testing code.

Singularity answered 25/10, 2015 at 3:27 Comment(7)
I added another tweak to my solution so that it is a bit more specific. Honestly, it shouldn't change the results but I didn't manage to run your benchmarks because of some inlining problems. Would you mind launching them again to test the new solution? :)Kentonkentucky
@Morwenn: Did you add -msse4 switch when compiling? If yes, you may also try to remove FORCEINLINE and everything related to it, perhaps it would help. I'll update benchmark code and timings a bit later, I promise.Singularity
Yes, I tried both, but for some reason, the only times when it compiled and ran, it didn't print anything, as if it silently crashed before reaching the printing instructions. I may try again, but still...Kentonkentucky
@Morwenn: I have updated the answer, the difference is smaller now. As for a crash before printing anything, it is very strange, because the first info printed is timing information for memcpy, and I'm pretty sure that are no errors up to that moment. It would be great to hear from some independent person, if he can compile and run my code.Singularity
@Singularity Result with your full source code: clang++ -msse4 -std=c++11 -O3 sorting.cpp -o sorting clang version 3.5.0 Linux savitri 3.16.7-24-desktop #1 SMP PREEMPT CPU: Intel Core i5-2410M CPU @ 2.30GHz memcpy only: 0.061 std::nth_element: 2.575 (974340096) bubble sort: 5.202 (974340096) selection sort: 2.859 (974340096) network sort: 0.777 (974340096) trivial count: 1.440 (974340096) vectorized count: 0.723 (974340096) vectorized count (n<=24): 0.705 (974340096) vectorized count (T): 0.707 (974340096) vectorized count (both): 0.440 (974340096) So both-vectorised-method is fastest.Thrombus
However, in my real program I could not use this since I am currently limited to SSE2 commands.Thrombus
@karsten: the both solution uses only SSE2. Just extract it into a separate cpp file, and you'll be able to compile and run it (see extracted version).Singularity

I found this problem interesting, so I tried all the algorithms I could think of.
Here are the results:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms

I think we can conclude that you were already using the fastest algorithm. Boy, was I wrong. However, if you can accept an approximate answer, there are probably faster ways, such as median of medians.
If you are interested, the source is here.

Pascha answered 24/10, 2015 at 13:27 Comment(3)
Thanks, nice comparison. You do not happen to know code for a hashset with a block allocator?Thrombus
A hashset is part of the standard library, under the name std::unordered_set. It supports custom allocators, and many good block allocators exist out in the wild; see for example herePascha
yes, I know. But we currently cannot use c++11 commands here.Thrombus

There are two more possible variants I would try (in addition to the SIMD-based sorting network for parallel loads).

The first is a SIMD-based rank filter, which keeps the 12 smallest values seen so far (assuming the median is needed), dropping the 11 largest values. The largest value left in the rank filter is then the median.

The trick to do this fast with SIMD is surprisingly trivial:

// initialise rank filter with  sorted = [oo oo oo oo oo ...], oo = inf
// repeat for every new element
// then at some point after the elements 0,0,1,44 have been inserted
// and we are inserting say, the value of 22

sorted      = [00 00 01 44 oo oo oo oo|oo oo oo oo oo oo oo oo]
---------------------------------------------------------------
shifted     = [00 00 00 01 44 oo oo oo|oo oo oo oo oo oo oo oo]
max 22      = [22 22 22 22 44 oo oo oo|oo oo oo oo oo oo oo oo]
min sorted  = [00 00 01 22 44 oo oo oo|oo oo oo oo oo oo oo oo]

In the case of the median, one needs an array of 2 SIMD registers covering 16 elements -- and one needs to shift across the SIMD register boundary.

If, however, the Nth element falls anywhere between 1 and 8 (or 17 and 24), then only one 16-byte SIMD register is needed.

Extracting the Nth element takes 22 iterations of those 3 macro operations, since the first element doesn't need sorting: shifted = [00 shifted](1:16); shifted = max(22, shifted); sorted = min(sorted, shifted); -- which should take just 6 instructions per iteration when the sorted array consists of two SIMD registers.
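The insertion step can be modelled in scalar code as follows (my own sketch of the idea above, not the SIMD implementation; names are illustrative):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>

// Scalar model of the rank filter: keep the 12 smallest values seen so
// far in ascending order. Each insert performs, over all lanes at once,
//   sorted[i] = min(sorted[i], max(x, shifted[i])), shifted = [0, sorted...]
// which becomes a downward loop here so that old values are read.
// After all 23 inserts, the last slot holds the median.
// Assumes no input value equals 0xFFFF (the "oo" sentinel).
uint16_t rank_filter_median23(const std::array<uint16_t, 23>& v)
{
    constexpr int K = 12;  // median rank: 12th smallest of 23
    std::array<uint16_t, K> sorted;
    sorted.fill(std::numeric_limits<uint16_t>::max());  // oo = infinity
    for (uint16_t x : v) {
        for (int i = K - 1; i > 0; --i)
            sorted[i] = std::min(sorted[i], std::max(x, sorted[i - 1]));
        sorted[0] = std::min(sorted[0], x);  // shifted[0] = 0, max(x, 0) = x
    }
    return sorted[K - 1];  // largest of the 12 smallest
}
```

The point of this formulation is that every lane's update is independent, which is exactly what makes the min/max pair map onto packed SIMD instructions.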

(It's also possible to make a three-way merging sort (batchers even/odd sort) by partially sorting the data in 3 SIMD registers, then merging, which should give some boost in speed)

The other version I'd consider is a radix-2 counting sort as in (https://mcmap.net/q/666909/-fast-7x7-2d-median-filter-in-c-and-c). The parametrisation here would be 15 rounds (for uint16_t input with values less than 32768), with the input transposed to 15 bit planes of uint32_t. The algorithm needs to visit every bit only once, so it can be slightly faster than comparing full uint16_t values. The algorithm can even be parallelised/vectorised on AVX2 with a popcount32 instruction for 4, 8 or 16 parallel loads, starting from Intel Skylake -- and with arm7/arm64 having popcount8 for 4 parallel loads per SIMD register.

template <int N> inline
uint32_t median32(uint32_t (&bits)[N], uint32_t mask, uint32_t threshold)
{
    uint32_t result = 0;
    int i = 0;
    do
    {
        uint32_t ones = mask & bits[i];
        uint32_t ones_size = popcount(ones);
        uint32_t mask_size = popcount(mask);
        auto zero_size = mask_size - ones_size;
        int new_bit = 0;
        if (zero_size < threshold)
        {
            // the answer lies among the candidates with this bit set
            new_bit = 1;
            threshold -= zero_size;
            mask = 0;
        }
        result = result * 2 + new_bit;
        mask ^= ones; // narrow the candidate set to the chosen half
    } while (++i < N);
    return result;
}

One will start with mask = 0b111'1111'1111'1111'1111'1111 (having the 23 least significant bits set) and threshold set between 1 and 23 to find the Nth element. Each iteration step can also be modified not to compute mask_size explicitly with popcount -- it starts at 23 bits set, and at the end of each round it is either ones_size or mask_size - ones_size, which could boost the performance on arm64. Every iteration yields one more correct bit of the Nth percentile, starting from the most significant one.

Astera answered 3/9, 2023 at 6:28 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.