Why isn't there a branch prediction failure penalty in this Rust code?
I've written this very simple Rust function:

fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }

    total
}

I've written a basic benchmark that invokes the function with an ordered array and with a shuffled one:

use criterion::{criterion_group, criterion_main, Criterion};
use rand::seq::SliceRandom;
use rand::thread_rng;

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;

    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE/2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE/2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

I'm surprised that the two benchmarks have almost exactly the same runtime, while a similar benchmark in Java shows a distinct difference between the two, presumably due to branch prediction failure in the shuffled case.

I've seen mention of conditional move instructions, but if I otool -tv the executable (I'm running on a Mac), I don't see any in the iterate method output.

Can anyone shed light on why there's no perceptible performance difference between the ordered and the unordered cases in Rust?

Hege answered 4/1, 2020 at 6:58 Comment(3)
I'm suspecting that this has to do with how Rust/LLVM optimizes such loops into SIMD instructions (which I believe Java is not able to do). – Trawler
@Frxstrem, yes, on my computer it uses the AVX ISA. Even in the Rust Playground, it flattens the logic with use of the "conditional move if less than" instruction cmovll. – Talent
@sshashank124: yup, with full optimization enabled (-O3) modern ahead-of-time compiler back-ends like LLVM and GCC will often do "if-conversion" of branching into CMOV or another branchless sequence. That's also a prerequisite for auto-vectorization. – Cactus

Summary: LLVM was able to remove/hide the branch by using either the cmov instruction or a really clever combination of SIMD instructions.
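The key observation is that the loop body is just summing absolute values, which is what makes a branchless rewrite possible at all. A minimal check of that equivalence (a sketch; `iterate` here takes `&[i32]` rather than the question's `&Box<[i32]>` for brevity):

```rust
// The question's loop, restated over a plain slice.
fn iterate(nums: &[i32]) -> i32 {
    let mut total = 0;
    for &n in nums {
        if n > 0 { total += n } else { total -= n }
    }
    total
}

// For every element: add it if positive, subtract it otherwise.
// That is exactly sum(|n|), so the branch can become an abs computation.
fn abs_sum(nums: &[i32]) -> i32 {
    nums.iter().map(|n| n.abs()).sum()
}
```

Because the two functions agree, the optimizer is free to replace the data-dependent branch with any branch-free absolute-value sequence.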


I used Godbolt to view the full assembly (with -C opt-level=3). I will explain the important parts of the assembly below.

It starts like this:

        mov     r9, qword ptr [rdi + 8]         ; r9 = nums.len()
        test    r9, r9                          ; if len == 0
        je      .LBB0_1                         ;     goto LBB0_1
        mov     rdx, qword ptr [rdi]            ; rdx = base pointer (first element)
        cmp     r9, 7                           ; if len > 7
        ja      .LBB0_5                         ;     goto LBB0_5
        xor     eax, eax                        ; eax = 0
        xor     esi, esi                        ; esi = 0
        jmp     .LBB0_4                         ; goto LBB0_4

.LBB0_1:
        xor     eax, eax                        ; return 0
        ret

Here, the function differentiates between 3 different "states":

  • Slice is empty → return 0 immediately
  • Slice length is ≤ 7 → use standard sequential algorithm (LBB0_4)
  • Slice length is > 7 → use SIMD algorithm (LBB0_5)
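Expressed back in plain (non-intrinsic) Rust, the generated dispatch looks roughly like this sketch. The label names in the comments refer to the assembly above; the function name is mine, and 16-element blocking with scalar `abs` stands in for the real SIMD loop:

```rust
fn iterate_like_asm(nums: &[i32]) -> i32 {
    if nums.is_empty() {
        return 0; // .LBB0_1: empty slice, return 0 immediately
    }
    if nums.len() <= 7 {
        // .LBB0_4: standard sequential loop (the cmov version)
        return nums.iter().map(|&n| n.abs()).sum();
    }
    // .LBB0_5: SIMD path, handling 16 integers per iteration,
    // plus a scalar cleanup loop for the leftover tail elements
    let chunks = nums.chunks_exact(16);
    let tail: i32 = chunks.remainder().iter().map(|&n| n.abs()).sum();
    let body: i32 = chunks.flatten().map(|&n| n.abs()).sum();
    body + tail
}
```

The real code differs in the details (unrolling, accumulator registers), but the three-way shape is the same.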

So let's take a look at the two different kinds of algorithms!


Standard sequential algorithm

Remember that rsi (esi) and rax (eax) were set to 0 and that rdx is the base pointer to the data.

.LBB0_4:
        mov     ecx, dword ptr [rdx + 4*rsi]    ; ecx = nums[rsi]
        add     rsi, 1                          ; rsi += 1
        mov     edi, ecx                        ; edi = ecx
        neg     edi                             ; edi = -edi
        cmovl   edi, ecx                        ; if ecx > 0 { edi = ecx }
        add     eax, edi                        ; eax += edi
        cmp     r9, rsi                         ; if rsi != len
        jne     .LBB0_4                         ;     goto LBB0_4
        ret                                     ; return eax

This is a simple loop iterating over all elements of nums. In the loop body there is a little trick, though: from the original element ecx, a negated copy is stored in edi. By using cmovl, edi is overwritten with the original value if that original value is positive. That means edi always ends up non-negative (i.e. it contains the absolute value of the original element). It is then added to eax (which is returned at the end).

So your if branch was hidden in the cmov instruction. As you can see in this benchmark, the time required to execute a cmov instruction is independent of the probability of the condition. It's a pretty amazing instruction!
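The same select-without-a-branch idea can be spelled out in source form. This is only a sketch of what the `neg` + `cmovl` pair computes per element (the function name is mine; the actual if-conversion happens in the backend and is not guaranteed at the source level):

```rust
// Per-element equivalent of the neg + cmovl sequence:
// edi holds -x, then is replaced by x if x was positive.
fn abs_via_select(x: i32) -> i32 {
    let negated = x.wrapping_neg(); // neg edi (wrapping, like the hardware)
    // cmovl: take the original value when it was positive,
    // otherwise keep the negated copy
    if x > 0 { x } else { negated }
}
```

Note that, like the hardware sequence, this wraps for `i32::MIN` instead of panicking, since `i32::MIN` has no positive counterpart.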


SIMD algorithm

The SIMD version consists of quite a few instructions that I won't fully paste here. The main loop handles 16 integers at once!

        movdqu  xmm5, xmmword ptr [rdx + 4*rdi]
        movdqu  xmm3, xmmword ptr [rdx + 4*rdi + 16]
        movdqu  xmm0, xmmword ptr [rdx + 4*rdi + 32]
        movdqu  xmm1, xmmword ptr [rdx + 4*rdi + 48]

They are loaded from memory into the registers xmm0, xmm1, xmm3 and xmm5. Each of those registers holds four 32-bit values, but to follow along more easily, just imagine each register holds exactly one value. All following instructions operate on each value of those SIMD registers individually, so that mental model is fine! My explanation below will also read as if each xmm register contained only a single value.

The main trick is now in the following instructions (which handle xmm5):

        movdqa  xmm6, xmm5      ; xmm6 = xmm5 (make a copy)
        psrad   xmm6, 31        ; arithmetic right shift by 31 bits (see below)
        paddd   xmm5, xmm6      ; xmm5 += xmm6
        pxor    xmm5, xmm6      ; xmm5 ^= xmm6

The arithmetic right shift fills the "empty" high-order bits (the ones "shifted in" on the left) with the value of the sign bit. By shifting by 31, we end up with only the sign bit in every position! So any non-negative number will turn into 32 zeroes and any negative number will turn into 32 ones. So xmm6 is now either 000...000 (if xmm5 is non-negative) or 111...111 (if xmm5 is negative).

Next this artificial xmm6 is added to xmm5. If xmm5 was positive, xmm6 is 0, so adding it won't change xmm5. If xmm5 was negative, however, we add 111...111 which is equivalent to subtracting 1. Finally, we xor xmm5 with xmm6. Again, if xmm5 was positive in the beginning, we xor with 000...000 which does not have an effect. If xmm5 was negative in the beginning we xor with 111...111, meaning we flip all the bits. So for both cases:

  • If the element was positive, we change nothing (the add and xor didn't have any effect)
  • If the element was negative, we subtracted 1 and flipped all bits. This is a two's complement negation!

So with these 4 instructions we calculated the absolute value of xmm5! Here again, there is no branch because of this bit-fiddling trick. And remember that xmm5 actually contains 4 integers, so it's quite speedy!
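The shift/add/xor trick isn't SIMD-specific; here it is as scalar Rust, so it can be checked against `i32::abs` directly (a sketch using wrapping arithmetic so that `i32::MIN` behaves like the hardware instructions, which silently wrap):

```rust
fn abs_bithack(x: i32) -> i32 {
    // psrad: arithmetic shift gives 0 for x >= 0, -1 (all ones) for x < 0
    let mask = x >> 31;
    // paddd then pxor: for negative x this computes !(x - 1) == -x;
    // for non-negative x, adding 0 and xoring with 0 changes nothing
    x.wrapping_add(mask) ^ mask
}
```

This is the classic two's-complement identity `-x = !(x - 1)` applied only where the sign mask is all ones.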

This absolute value is now added to an accumulator and the same is done with the three other xmm registers that contain values from the slice. (We won't discuss the remaining code in detail.)


SIMD with AVX2

If we allow LLVM to emit AVX2 instructions (via -C target-feature=+avx2), it can even use the vpabsd instruction (the AVX form of pabsd) instead of the four "hacky" instructions:

vpabsd  ymm2, ymmword ptr [rdx + 4*rdi]

It loads the values directly from memory, calculates their absolute values, and stores them in ymm2, all in one instruction! And remember that ymm registers are twice as large as xmm registers (they fit eight 32-bit values)!

Scully answered 4/1, 2020 at 10:39 Comment(4)
You might want to tell LLVM not to unroll loops so you can see what it's doing without getting bogged down in the unrolling. For clang the option is -fno-unroll-loops, but that option name might just be for GCC compat, not LLVM's own internal name. Also, if you let it use SSSE3 or AVX2, it will hopefully use pabsd to do SIMD absolute value in one instruction instead of needing the 2's complement identity -x = ~(x - 1) bithack. – Cactus
@PeterCordes Thanks! I added some information about pabsd. With AVX2 the assembly is indeed much nicer. – Scully
Too bad LLVM still uses an indexed addressing mode even when unrolling, so the instruction costs 2 fused-domain uops on Intel CPUs. :/ Micro fusion and addressing modes. It probably doesn't bottleneck on the front-end, though, even with data hot in L1d cache, with vpabsd [mem] + vpaddd only being a total of 3 fused-domain uops on Haswell/Skylake. (And the pipeline being 4-wide, so there's room for the loop overhead.) – Cactus
Thanks for the great answer! I learned a bunch from that - and thanks for the link to Godbolt! Using it as a repl, I was able to get rid of the conditional move (and demonstrate the branch prediction failure penalty I was looking for) by just making the if body a little more complicated. e.g., converting total += nums[i] to something bigger like total += nums[i]*(nums[i]-1). I guess this could technically be done with conditional moves still, but the optimizer's heuristics just decide that branching is a better approach here? – Hege
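For reference, the heavier-body variant described in the comment above looks like this (a sketch; whether the optimizer still if-converts it is a target- and version-dependent cost heuristic, not a guarantee):

```rust
fn iterate_branchy(nums: &[i32]) -> i32 {
    let mut total = 0;
    for &n in nums {
        if n > 0 {
            // heavier branch body; per the comment, this tipped the
            // heuristics toward a real branch instead of a cmov
            total += n * (n - 1);
        } else {
            total -= n;
        }
    }
    total
}
```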

© 2022 - 2024 — McMap. All rights reserved.