How to find the first nonzero in an array efficiently?
Suppose we want to quickly find the index of the first nonzero element in an array, to the effect of

fn leading_zeros(arr: &[u32]) -> Option<usize> {
    arr.iter().position(|&x| x != 0)
}

However, rustc compiles this to a one-by-one check, as seen here. One can speed this up a little by checking the words four at a time via the u128 type, as follows. This gives a speedup of roughly 3x on my machine.

fn leading_zeros_wide(arr: &[u32]) -> Option<usize> {
    // Split the slice into an unaligned prefix, a u128-aligned middle, and
    // a tail; the middle lets us test four u32s per comparison.
    let (beg, mid, _) = unsafe { arr.align_to::<u128>() };

    beg.iter().position(|&x| x != 0).or_else(|| {
        let left = beg.len() + 4 * mid.iter().position(|&x| x != 0).unwrap_or(mid.len());
        arr[left..].iter().position(|&x| x != 0).map(|p| p + left)
    })
}

Is there a way to make this even faster?


Here is the benchmark I used to measure the 3x speedup:

#![feature(test)]
extern crate test;

fn v() -> Box<[u32]> {
    std::iter::repeat(0).take(1000).collect()
}

// Assume `leading_zeros` and `leading_zeros_wide` are defined here.

#[bench]
fn bench_leading_zeros(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros(&v[3..]))
}

#[bench]
fn bench_leading_zeros_wide(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros_wide(&v[3..]))
}
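(The test crate is unstable, so this harness needs a nightly toolchain; something like cargo +nightly bench should run it.)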
Dictatorial answered 27/12, 2021 at 18:3 Comment(12)
@Stargateur The line starting with let left skips the zeros four at a time, by interpreting 4 adjacent words as a single u128. If we cannot skip zeros this way, we fall back to scanning one by one.Dictatorial
@JohnKugelman I didn't use the end parameter because the slice arr[left..] contains that part.Dictatorial
@JohnKugelman Hmm, with your explanation I see why end is ignored; I think the code deserves some annotation or better variable naming. Looks OK to me now. That said, since the question claims this is faster, I think having benchmark code to test it in the question would be a plus if not a requirement.Propraetor
I think docs.rs/memx/latest/memx/fn.memnechr.html should be faster and more reliablePropraetor
As already said here, use memchr(). Other than that, in similar cases, use SIMD. (A byte-level sketch follows these comments.)Heteroecious
Thanks all! Sadly, the memx crate appears to have a bug at the moment in memnechr (at least in 0.1.18)Dictatorial
I see that your optimized version is still not SIMD even when SIMD features are specified as compiler options: rust.godbolt.org/z/8scnKToq8, which means it can be optimized further. Apparently there is a way to use CPU intrinsics directly: x86, arm. Sorry, I will not provide this solution; I don't know Rust (I found this question via the [simd] tag)Trilbi
I don't know how to use SIMD intrinsics in Rust, but the asm instructions you want it to emit on x86 are to search for a vector containing a non-zero element, then use "Is there an efficient way to get the first non-zero element in an SIMD register using SIMD intrinsics?" to find the position within that vector. Like my AVX2 C intrinsics answer on "Efficiently find least significant set bit in a large array?" (which does a bit-scan on the non-zero element once it finds it, to find the bit-position)Lucia
@PeterCordes I'm not worried about the 'last mile', that is, finding the u32 inside a chunk. I just couldn't get rustc to vectorize the middle part, which is what will have the most impact.Dictatorial
If you need to manually vectorize anyway, you should definitely do it in a way that gets the element index efficiently. Some ways of looking for a non-zero vector on x86 involve pcmpeqd / movmskps anyway, so you already have the compare-result bitmap in an integer register just waiting for a bit-scan.Lucia
You probably do need to manually vectorize; LLVM and GCC's optimizers don't know how to auto-vectorize loops whose trip-count isn't known ahead of the first iteration. (i.e. search loops or other early-out conditions defeat them.) ICC can auto-vectorize such loops. You could maybe get something from portable code by unconditionally doing 4 u32 elements in an inner loop (a portable sketch follows these comments), but it's probably hard to get rustc to spit out a simple pcmpeqd / pmovmskb, rather than some silly horizontal reduction.Lucia
As you want to run this on aarch64: please clarify whether you want an aarch64-specific solution, or a generic solution that does not use architecture-specific intrinsicsTrilbi
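A portable sketch of that chunked idea (the function name is illustrative): OR four u32s unconditionally per iteration, so the hot loop takes one early-out branch per chunk rather than per element, and locate the exact element only after a chunk tests nonzero.

fn first_nonzero_chunked(arr: &[u32]) -> Option<usize> {
    let mut chunks = arr.chunks_exact(4);
    let mut base = 0usize;
    for c in &mut chunks {
        // One OR-reduction and one branch per 4 elements.
        if (c[0] | c[1] | c[2] | c[3]) != 0 {
            return c.iter().position(|&x| x != 0).map(|p| base + p);
        }
        base += 4;
    }
    // Scalar scan over the remainder (fewer than 4 elements).
    chunks.remainder().iter().position(|&x| x != 0).map(|p| base + p)
}

The byte-search suggestion can likewise be mimicked in plain std (an optimized memnechr/memchr-style routine would replace the scalar byte scan here): every u32 before the first nonzero one consists entirely of zero bytes, so the index of the first nonzero byte divided by 4 is the answer, regardless of endianness.

fn first_nonzero_via_bytes(arr: &[u32]) -> Option<usize> {
    // u8 has alignment 1, so align_to leaves no unaligned prefix or suffix.
    let (_, bytes, _) = unsafe { arr.align_to::<u8>() };
    bytes.iter().position(|&b| b != 0).map(|i| i / 4)
}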
64 bit: https://rust.godbolt.org/z/rsxh8P8Er

32 bit: https://rust.godbolt.org/z/3P3ejsnh1

I have only a little experience with Rust and assembly, but I added some tests.

#[cfg(target_feature = "avx2")]
pub mod avx2 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    fn first_nonzero_tiny(arr: &[u32]) -> Option<usize> {
        arr.iter().position(|&x| x != 0)
    }

    /// Compares the 8 u32s starting at `offset` against zero and returns an
    /// 8-bit mask with bit k set iff element `offset + k` is zero.
    fn find_u32_zeros_8elems(arr: &[u32], offset: isize) -> i32 {
        unsafe {
            let ymm0 = _mm256_setzero_si256();
            let mut ymm1 = _mm256_loadu_si256(arr.as_ptr().offset(offset) as *const __m256i);
            ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
            let ymm2 = _mm256_castsi256_ps(ymm1);
            _mm256_movemask_ps(ymm2)
        }
    }

    pub fn first_nonzero(arr: &[u32]) -> Option<usize> {
        let size = arr.len();
        if size < 8 {
            return first_nonzero_tiny(arr);
        }

        let mut i: usize = 0;
        // Largest multiple of 8 that fits: whole 8-element chunks first.
        let simd_size = size / 8 * 8;
        while i < simd_size {
            let mask: i32 = find_u32_zeros_8elems(&arr, i as isize);
            // Bit k of `mask` is set when lane k compared equal to zero,
            // so 255 means all 8 lanes of this chunk were zero.
            if mask != 255 {
                // trailing_ones counts the zero lanes before the first
                // nonzero lane, i.e. the index within this chunk.
                return Some((mask.trailing_ones() as usize) + i);
            }
            i += 8;
        }

        // One final (possibly overlapping) 8-element check that ends
        // exactly at the end of the array.
        let last_chunk = size - 8;
        let mask: i32 = find_u32_zeros_8elems(&arr, last_chunk as isize);
        if mask != 255 {
            return Some((mask.trailing_ones() as usize) + last_chunk);
        }

        None
    }
}

use avx2::first_nonzero;

pub fn main() {
    let v = [0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [2];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [1, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [0, 1, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(1));

    let v = [0, 0, 1, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(2));

    let v = [0, 0, 0, 1, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(3));

    let v = [0, 0, 0, 0, 1, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(4));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(6));

    let v = [0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(7));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(8));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(16));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(15));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 3, 4, 5];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(14));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(17));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(18));
}
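Note that the avx2 module above only exists when AVX2 is enabled at compile time (e.g. -C target-feature=+avx2). A possible alternative, sketched below with made-up names rather than taken from the answer, is runtime dispatch via is_x86_feature_detected!, with the same compare/movemask pattern inside:

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub fn first_nonzero_runtime(arr: &[u32]) -> Option<usize> {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just verified at runtime.
        return unsafe { first_nonzero_avx2(arr) };
    }
    // Scalar fallback for CPUs without AVX2.
    arr.iter().position(|&x| x != 0)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn first_nonzero_avx2(arr: &[u32]) -> Option<usize> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = 0;
    while i + 8 <= arr.len() {
        // Same 8-lane compare + movemask as in the answer's helper.
        let v = _mm256_loadu_si256(arr.as_ptr().add(i) as *const __m256i);
        let eq = _mm256_cmpeq_epi32(v, _mm256_setzero_si256());
        let mask = _mm256_movemask_ps(_mm256_castsi256_ps(eq));
        if mask != 0xff {
            return Some(i + mask.trailing_ones() as usize);
        }
        i += 8;
    }
    // Scalar tail (fewer than 8 elements).
    arr[i..].iter().position(|&x| x != 0).map(|p| i + p)
}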
Poinciana answered 30/12, 2021 at 23:33 Comment(15)
Looks good. It should be possible to handle the tail with SIMD, too, for any size>=8, with the last vector being an unaligned load that ends at the end of the array. (_mm256_loadu_si256; surprised you use alignment-required load in the loop without commenting on that input requirement). (It might be convenient to wrap the SIMD stuff in a helper function, instead of the scalar fallback).Lucia
It would be more convenient to handle that if you use i += 8 and use it as an offset to a pointer-to-i32 before casting the result to __m256i*, like C-style _mm_loadu_si128( (const __m128i*) &arr[i] ) instead of i + (const __m128i*)arr. Even for your current code, that would let you use the ending i value, instead of n*8 .. arr.len(). Although then you'd have to do i < n-7; you're solving that vector overshoot problem with n/=8; so at this point it's just two equally good styles of doing array indexing for manual SIMD.Lucia
Thank you @PeterCordes , I edited my answer and now it uses _mm256_loadu_si256 instead of _mm256_load_si256Poinciana
@PeterCordes I refactored my code with your suggestions.Poinciana
Nice, that saved an instruction inside the loop, now it's only incrementing i, not also updating some other counter for the benefit of code after the loop. A more meaningful name for the helper function might be find_u32_zeros_8elems - assigning the result to a variable called mask is sufficient reminder that it's a compare bitmask.Lucia
rustc's asm isn't really optimal, but fairly reasonable. e.g. could save a uop with a memory-source compare (but only if it avoided an indexed addressing mode for Intel SnB-family CPUs, which would mean more work to calc i outside the loop after the pointer increment). And it could avoid duplicating the pcmpeq/movmskps stuff, sharing that between both ways out of the function if it planned regs for it. I was hoping it would use tzcnt over bsf when BMI was available (target-feature=+avx2,bmi,bmi2), but no: rust.godbolt.org/z/r5r7z6dG6.Lucia
Interesting, -C target-cpu=haswell (rust.godbolt.org/z/cP4zn5hx1) does get it to use tzcnt. But still the useless movzx after movmskps, not even going to a different register, defeating mov-elimination. But it does tzcnt (or bsf) into a different register, causing a false-dependency on RDX, which that path of execution through the function didn't previously write. (BSF always has an output dependency so it can leave dst unmodified on input=0; TZCNT does on SnB-family before Skylake). Anyway, those are rustc / LLVM missed optimizations, nothing you can fix in the src.Lucia
I renamed the helper function. And now I know that bmi and bmi2 are not implied when specifying avx2Poinciana
@PeterCordes, apparently target-feature=+avx2,bmi,bmi2 doesn't enable tzcnt, you have to use target-feature=+avx2,+bmi (bmi2 doesn't seem to be needed here)Trilbi
@AlexGuteniev: Ah, thanks. I should have checked the manual instead of guessing syntax. (And yes, tzcnt is part of BMI1; I threw in BMI2 to see if shrx and so on gave rustc anything useful to play with.)Lucia
@IgorZhukov Thank you so much! Will follow up with bench results on my machine (which will take time since I need to port it to aarch64)Dictatorial
@MertSağlam, oh. Porting this to aarch64 means solving this again for aarch64, as manual vectorization implies CPU-architecture-specific intrinsics. We thought you needed x86 because the Godbolt link showed x86. I don't know if it can show aarch64 thoughTrilbi
@Alex Guteniev, It can show aarch64 assembly, but I have even less experience with aarch64 assembly :( Example: rust.godbolt.org/z/MWvnK1azzPoinciana
I found an example on GitHub which uses Rust aarch64 intrinsics: github.com/3andne/trojan-oxide/blob/main/src/simd/simd_parse.rs And I managed to compile it at godbolt: rust.godbolt.org/z/WhhqEzrxqPoinciana
@MertSağlam I did a port: rust.godbolt.org/z/K9GW6nxds But I didn't test it. It could be completely broken, but maybe it saves some time for you...Poinciana
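For what it's worth, a minimal aarch64 sketch with std::arch::aarch64 intrinsics (likewise untested, and depending on the toolchain these intrinsics may still require nightly) could look like this; vmaxvq_u32 reduces the four lanes to their maximum, which is zero iff every lane is zero:

#[cfg(target_arch = "aarch64")]
fn first_nonzero_neon(arr: &[u32]) -> Option<usize> {
    use std::arch::aarch64::*;

    let mut i = 0;
    // NEON is baseline on aarch64; unsafe is only for the raw loads.
    unsafe {
        while i + 4 <= arr.len() {
            let v = vld1q_u32(arr.as_ptr().add(i));
            // Horizontal max over 4 lanes: nonzero iff some lane is nonzero.
            if vmaxvq_u32(v) != 0 {
                return arr[i..i + 4].iter().position(|&x| x != 0).map(|p| i + p);
            }
            i += 4;
        }
    }
    // Scalar tail (fewer than 4 elements).
    arr[i..].iter().position(|&x| x != 0).map(|p| i + p)
}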
Here is a solution which is faster than the baseline, but probably still leaves a ton on the table.

The following achieves a 7.5x speedup over the baseline first_nonzero.

/// Finds the position of the first nonzero element in a given slice which
/// contains a nonzero.
///
/// # Safety
///
/// The caller *has* to ensure that the input slice has a nonzero, and that
/// it occurs no later than the u128-aligned middle of the slice: the
/// unrolled loop below reads u128 chunks without bounds checks, so a
/// nonzero appearing only in the unaligned tail would let it read past the
/// end (hence the _padded in the name).
unsafe fn first_nonzero_padded(arr: &[u32]) -> usize {
    let (beg, mid, _) = arr.align_to::<u128>();
    beg.iter().position(|&x| x != 0).unwrap_or_else(|| {
        let left = beg.len()
            + 4 * {
                let mut p: *const u128 = mid.as_ptr();
                // Unrolled 8-wide scan for the first nonzero u128; relies
                // on the safety precondition instead of bounds checks.
                loop {
                    if *p.offset(0) != 0 { break p.offset(0); }
                    if *p.offset(1) != 0 { break p.offset(1); }
                    if *p.offset(2) != 0 { break p.offset(2); }
                    if *p.offset(3) != 0 { break p.offset(3); }
                    if *p.offset(4) != 0 { break p.offset(4); }
                    if *p.offset(5) != 0 { break p.offset(5); }
                    if *p.offset(6) != 0 { break p.offset(6); }
                    if *p.offset(7) != 0 { break p.offset(7); }
                    p = p.offset(8);
                }.offset_from(mid.as_ptr()) as usize
            };
        if let Some(p) = arr[left..].iter().position(|&x| x != 0) {
            left + p
        } else {
            core::hint::unreachable_unchecked()
        }
    })
}
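A minimal call-site sketch, assuming the caller has just established the precondition (the nonzero word lands well inside the aligned middle here):

let mut buf = vec![0u32; 1024];
buf[777] = 42; // the caller guarantees a nonzero exists
let idx = unsafe { first_nonzero_padded(&buf) };
assert_eq!(idx, 777);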
Dictatorial answered 30/12, 2021 at 15:47 Comment(3)
Is there a way to compile (on Godbolt) the attempted SIMD version from the first revision of this answer, with use core_simd::u64x2; and so on? godbolt.org/z/E6ozdhdYc didn't work for me with rustc nightly. If it was slower, very likely your mask8x8::from_array([ *p.offset(00) != ZERO, up to 07 ]) didn't compile to a single SSE4.1 pcmpeqq or whatever. IDK if that would spend a lot of scalar work packing 8x 2-bit compare results into a single mask8x8, or worse booleanizing those 2-bit results into 1-bit results?Lucia
But anyway, describing that as pcmpeqd / tzcnt is almost certainly bogus, so yeah no wonder you deleted it from your answer :P And I'm not surprised an early-out on 16-byte chunks is a bit better; you want the inner loop to not spend a ton of work preparing for the stuff after the loop that sorts out where the non-zero element was. e.g. if you expect long runs of zeros, you can even OR together multiple vectors, then re-check them individually later (sketched after these comments). (Working in cache-line sized chunks is good, especially if your data is aligned by 64)Lucia
Your current code is doing scalar OR of two 64-bit chunks, branching on FLAGS set by that. godbolt.org/z/6fMEvveMbLucia
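A sketch of that OR-multiple-vectors idea with AVX2 intrinsics, assuming the caller has already verified AVX2 support (names are illustrative): two 256-bit loads cover one 64-byte cache line, and _mm256_testz_si256 leaves a single branch per line in the hot loop.

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn first_nonzero_cacheline(arr: &[u32]) -> Option<usize> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut i = 0;
    // One cache line = 16 u32s = two 256-bit vectors.
    while i + 16 <= arr.len() {
        let a = _mm256_loadu_si256(arr.as_ptr().add(i) as *const __m256i);
        let b = _mm256_loadu_si256(arr.as_ptr().add(i + 8) as *const __m256i);
        // OR the two halves; testz sets ZF iff the result is all-zero,
        // so the loop does no per-element work while skipping zeros.
        let combined = _mm256_or_si256(a, b);
        if _mm256_testz_si256(combined, combined) == 0 {
            // Something in this cache line is nonzero; sort out where.
            return arr[i..i + 16].iter().position(|&x| x != 0).map(|p| i + p);
        }
        i += 16;
    }
    // Scalar tail (fewer than 16 elements).
    arr[i..].iter().position(|&x| x != 0).map(|p| i + p)
}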
