Why is this binary search optimization much slower?

A supposed optimization made the code over twice as slow.

I counted how often a value x occurs in a sorted list a by finding the range where it occurs:

from bisect import bisect_left, bisect_right

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

But hey, it can't stop before it starts, so we can optimize the second search by leaving out the part before start (doc):

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

But when I benchmarked, the optimized version took over twice as long:

 254 ms ±  1 ms  original
 525 ms ±  2 ms  optimized

Why?

The benchmark builds a sorted list of ten million random ints from 0 to 99,999 and then counts every distinct int (this is purely for benchmarking, so no need to point out Counter) (Try it online!):

import random
from bisect import bisect_left, bisect_right
from timeit import repeat
from statistics import mean, stdev

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

def count_all():
    for x in unique:
        count(a, x)
for count in original, optimized:
    times = repeat(count_all, number=1)
    ts = [t * 1e3 for t in sorted(times)[:3]]  # keep the three fastest of timeit.repeat's runs, in ms
    print(f'{round(mean(ts)):4} ms ± {round(stdev(ts)):2} ms ', count.__name__)
Reproachful answered 8/7, 2022 at 15:59 Comment(16)
Very strange. I've reproduced this here as well. 148ms vs 283ms.Skimmer
The only difference is the extra overhead of passing the parameter, but that should be negligible compared to the cost of the search.Skimmer
And I tried changing original to use stop = bisect_right(a, x, start-start) so it would pass an explicit second argument. That made it a tiny bit slower, but still nothing like optimized.Skimmer
Is there somehow a caching bonus when both calls run on exactly the same (segment of) the array?Frantic
I'm thinking something similar, it has to do with locality.Skimmer
Another idea: We "know" something about this particular data which the binary search algorithm doesn't, i.e. that for most values of x, stop will not be very far from start (on average, 100 positions away in a list of 10 million items). With that in mind, does giving a "hint" actually help the algorithm? Effectively that makes the search try first at ((len(a) + start) / 2), which on average will be too far to the right. Does this turn out to be actually worse than starting at len(a)/2?Frantic
@Frantic I don't think it's "too far to the right", although I'm not really sure what you mean. The stop points on average are to the right of len(a)/2. And reducing the search range reduces the number of halvings needed.Reproachful
What I was thinking was: let's say that for some x, start=8e6 (with this list of size 1e7). Then with optimized, you start looking for stop at 9e6, which is too far to the right except in the unlikely case that there are 1e6 instances of the relevant value in the list. But I'm wrong that this would explain the time difference, because the optimized version results in fewer total comparisons being done:Frantic
tio.run/…Frantic
(my_bisect_left and my_bisect_right there are copied from the Python source code, except for adding the comparison counter)Frantic
@Frantic I did a similar test, where I instead make the objects count their comparisons (reduced numbers, since that makes it slower). It agrees, fewer comparisons with the optimized version:Reproachful
tio.run/…Reproachful
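
A minimal sketch of that kind of comparison counting (a hypothetical CountingInt helper, not the exact code behind the truncated tio.run links) could look like this:

import random
from bisect import bisect_left, bisect_right

comparisons = 0

class CountingInt(int):
    # Hypothetical wrapper: count every "<" the searches perform.
    def __lt__(self, other):
        global comparisons
        comparisons += 1
        return int.__lt__(self, other)

a = sorted(CountingInt(v) for v in random.choices(range(1000), k=100_000))

def total_comparisons(use_hint):
    global comparisons
    comparisons = 0
    for x in map(CountingInt, set(a)):
        start = bisect_left(a, x)
        bisect_right(a, x, start if use_hint else 0)
    return comparisons

print('original :', total_comparisons(False))
print('optimized:', total_comparisons(True))
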
I wonder if it's that without the hint, bisect_right hits largely the same array elements that bisect_left did, and gets a caching benefit from that. Let's say that our value x occurs at positions 850 through 900 in a 10k element array. bisect_left accesses elems 5000, 2500, 1250, 625, 938, 782, 860, 821, 841, 851, 846, 849, 850. "Unoptimized" bisect_right tries: 5000, 2500, 1250, 625, 938, 782, 860, 899, 919, 909, 904, 902, 901, 900. So the first 6 of "unoptimized" bisect_right's 14 list accesses have just been done by bisect_left: are they available in cache, hence the time difference?Frantic
@Frantic Probably. Would be nice if someone could measure the number/time of cache/page hits/misses. (I have no experience with that, and currently only have a phone, no working PC.)Reproachful
@KellyBundy: On a program-wide basis, Python is too random to make that all that useful (generating the data creates many page faults). That said, a minor tweak that reduces the impact of any caching issues is to ensure all ints with the same value are the same object, and that they're generated in order all at once (so they're likely mostly contiguous in memory). All you do is change range(100_000) to tuple(range(100_000)). Doing that improves "original" by ~34%, and "optimized" by ~61%. Reducing the impact of the cache improves optimized more, implying it's in fact a cache miss issue.Haynor
@Haynor Hmm, yes, ideally one would only profile the search part of the program. Perhaps by waiting for key press before and after, and starting/stopping the profiling there. Or by profiling the generation separately and subtracting it out. Don't know what is possible and works well. Your test+argument is nice, might be good in an answer. (Actually the main reason I only upvoted Tim's but haven't accepted it yet is that I'd like to see at least some actual demonstration of the cache effects... That's partly why I asked the question in the first place.)Reproachful

There are a couple respects in which the benchmark triggers adverse cache effects.

First, I bet this assert will pass for you (as it does for me):

assert list(unique) == sorted(unique)

There's no guarantee that will pass, but given the implementations of CPython's set type and integer hashing to date, it's likely to pass.
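
A quick way to check that (it relies on CPython implementation details, not on anything the language guarantees): small ints hash to themselves, and once the set's hash table has more slots than the largest value, each value lands in "its own" slot, so iteration comes out in ascending order:

import random

print(hash(4096))  # 4096: small ints hash to themselves in CPython

a = sorted(random.choices(range(100_000), k=1_000_000))
unique = set(a)
# ~100,000 distinct values force a table with more than 100,000 slots,
# so hash & mask == value for every element and iteration is sorted.
print(list(unique) == sorted(unique))  # typically True in current CPython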

That implies your for x in unique is trying x in strictly increasing order. That makes the potential probe sequences inside bisect_left() much the same from one x to the next, so many of the values being compared are likely sitting in cache. The same is true of bisect_right() in the original, but in the optimized version the potential probe sequences for bisect_right() differ across tries because the start index differs across tries.

To make both versions slow down "a lot", add this after the assert:

unique = list(unique)
random.shuffle(unique)

Now there's no regularity in the input x across tries, so there's no systematic correlation in the potential probe sequences across tries either.

The other cache effects come within a single try. In the original, the potential probe sequences are exactly the same between bisect_left() and bisect_right(). Entries read up to resolve bisect_left() are very likely still sitting in cache for bisect_right() to reuse.

But in the optimized version, the potential probe sequences differ because the slice bounds differ. For example, bisect_left() will always start by comparing x to a[5000000]. In the original, bisect_right() will also always start by making that same compare, but in the optimized version it will almost always pick a different index of a to start with, and that index will be sitting in cache only by luck.
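
To make that concrete, here is a rough sketch: pure-Python versions of the bisection loops, instrumented to record the indices they probe (the helper names are invented for illustration):

def probed_left(a, x, lo=0, hi=None):
    # bisect_left's loop, recording every index it reads
    hi = len(a) if hi is None else hi
    seen = []
    while lo < hi:
        mid = (lo + hi) // 2
        seen.append(mid)
        if a[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return lo, seen

def probed_right(a, x, lo=0, hi=None):
    # bisect_right's loop, recording every index it reads
    hi = len(a) if hi is None else hi
    seen = []
    while lo < hi:
        mid = (lo + hi) // 2
        seen.append(mid)
        if x < a[mid]:
            hi = mid
        else:
            lo = mid + 1
    return lo, seen

a = sorted(list(range(10_000)) + [5] * 50)
start, left = probed_left(a, 5)
_, right_full = probed_right(a, 5)           # original: same slice bounds as the left search
_, right_hinted = probed_right(a, 5, start)  # optimized: lo=start changes the bounds
print(left[:6])          # [5025, 2512, 1256, 628, 314, 157]
print(right_full[:6])    # identical first probes, already read by the left search
print(right_hinted[:6])  # starts at a different midpoint right away

The full-range bisect_right retraces bisect_left's first probes exactly, while the hinted call's very first read is at an index the left search never touched.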

All that said, I usually use your optimization in my own code. But that's because I typically have comparison operations far more expensive than integer compares, and saving a compare is worth a lot more than saving some cache misses. Comparison of small ints is very cheap, so saving some of those is worth little.

Manon answered 8/7, 2022 at 17:42 Comment(7)
Ha, originally I had shuffled, in order to avoid the set's "sorting". Then removed that to simplify the benchmark, as the speed ratio remained similar. I think there's still another cache effect, also pointed out by slothrop: even shuffled, the "top" indexes in the bisection tree remain used all the time (when using the full range, but not when using the start).Reproachful
Yes, that's what I was getting in the second half of my answer: "The other cache effects come within a single try ...".Manon
Oh is the "within a single try" not referring to a single bisect_left+bisect_right pair? That's not what I mean. I'm talking about across all the searches. Consider all bisect_left calls. Even when searching in shuffled order, the top index is always the same, the two second-to-top indexes each are accessed in half of all searches, etc.Reproachful
It's referring to both, actually. There are no qualifiers in "For example, bisect_left() will always start by comparing x to a[5000000]. In the original, bisect_right() will also always start by making that same compare, but ...". That much is true both within and across tries.Manon
It is worth noting that the "other cache effects within a single try" are not merely a result of a poor benchmark. That's a natural consequence of the optimization itself. You're correct that it's specific to cheap comparisons though; if I tweak the benchmark to make comparisons just a little more expensive (by making an int subclass with an overridden __lt__ that returns NotImplemented unconditionally, so it ends up performing a useless call into Python, then falling back to the C implementation of __gt__), as well as using unique instances for each value, optimized ties it up.Haynor
Try it online!Haynor
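
For reference, a sketch along the lines of what that comment describes (a hypothetical SlowInt subclass, not the exact code behind the truncated link). Returning NotImplemented from __lt__ forces a useless call into Python on every comparison before the interpreter falls back to the reflected int comparison in C:

import random
from bisect import bisect_left, bisect_right
from timeit import timeit

class SlowInt(int):
    # Hypothetical wrapper: make each "<" slightly more expensive than a plain int compare.
    def __lt__(self, other):
        return NotImplemented  # the interpreter then tries other.__gt__, i.e. int's C comparison

# a unique object per element, as described above
a = sorted(SlowInt(v) for v in random.choices(range(100_000), k=1_000_000))
unique = [SlowInt(v) for v in set(a)]

def original(x):
    start = bisect_left(a, x)
    return bisect_right(a, x) - start

def optimized(x):
    start = bisect_left(a, x)
    return bisect_right(a, x, start) - start

for f in original, optimized:
    print(f.__name__, timeit(lambda: [f(x) for x in unique], number=1))
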
@Tim Ok... I understood the "within a single try" as such a qualifier. Sounded like you were only talking about within each left+right search pair, and the subsequent discussion doesn't widen that impression, as it focuses on hits/misses in bisect_right.Reproachful

I tried simulating caching in Python and measured cache misses for various cache sizes:

        ORIGINAL:              OPTIMIZED:
cache |         cache-misses |         cache-misses |
 size |  time   line   item  |  time   line   item  |  
------+----------------------+----------------------+
 1024 | 1.98 s  59.1%  16.4% | 4.90 s  74.4%  57.1% |
 2048 | 2.30 s  59.1%  16.4% | 5.28 s  72.5%  56.8% |
 4096 | 2.16 s  59.0%  16.4% | 5.30 s  70.4%  56.4% |
 8192 | 2.33 s  59.0%  16.4% | 6.09 s  68.2%  56.0% |
16384 | 2.80 s  59.0%  16.4% | 6.30 s  65.8%  55.6% |

I used a proxy object for the list. Getting a list item goes through the getitem function, which has an LRU cache of the size shown in the leftmost column. And getitem doesn't access the list directly, either. It goes through the getline function, which fetches a "cache line", a block of 8 consecutive list elements. It has an LRU cache of the cache size divided by 8.

It's far from perfect, i.e., far from the real thing of measuring actual hardware cache misses, especially since it only simulates caching of the references stored in the list, not the int objects they refer to. But I find it interesting nonetheless: the original version of my function shows fewer cache misses, and its miss rates stay pretty constant across the various cache sizes, while the optimized version shows more cache misses, and larger caches help reduce its miss rates.

My code (Try it online!):

import random
from bisect import bisect_left, bisect_right
from timeit import timeit
from functools import lru_cache

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

class Proxy:
    # Stand-in for the real list: every item access goes through the simulated cache below.
    __len__ = a.__len__
    def __getitem__(self, index):
        return getitem(index)
p = Proxy()

def count_all():
    for x in unique:
        count(p, x)

linesize = 8

print('''        ORIGINAL:              OPTIMIZED:
cache |         cache misses |         cache misses |
 size |  time   line   item  |  time   line   item  |  
------+----------------------+----------------------+''')

for cachesize in 1024, 2048, 4096, 8192, 16384:
    print(f'{cachesize:5} |', end='')

    # "Main memory" access: fetch a whole cache line of `linesize` consecutive elements.
    @lru_cache(cachesize // linesize)
    def getline(i):
        i *= linesize
        return a[i : i+linesize]

    # Item access: resolved through its cache line, with its own LRU cache of references.
    @lru_cache(cachesize)
    def getitem(index):
        q, r = divmod(index, linesize)
        return getline(q)[r]
    
    for count in original, optimized:
        getline.cache_clear()
        getitem.cache_clear()
        time = timeit(count_all, number=1)
        def misses(func):
            ci = func.cache_info()
            misses = ci.misses / (ci.misses + ci.hits)
            return f'{misses:.1%}'
        print(f'{time:5.2f} s  {misses(getline)}  {misses(getitem)}', end=' |')
    print()
Reproachful answered 9/7, 2022 at 0:3 Comment(0)
