Count leading zero bits for each element in AVX2 vector, emulate _mm256_lzcnt_epi32

With AVX512, there is the intrinsic _mm256_lzcnt_epi32, which returns a vector that, for each of the 8 32-bit elements, contains the number of leading zero bits in the input vector's element.

Is there an efficient way to implement this using AVX and AVX2 instructions only?

Currently I'm using a loop which extracts each element and applies the _lzcnt_u32 function.


Related: to bit-scan one large bitmap, see Count leading zeros in __m256i word which uses pmovmskb -> bitscan to find which byte to do a scalar bitscan on.

This question is about doing 8 separate lzcnts on 8 separate 32-bit elements when you're actually going to use all 8 results, not just select one.

Skilken answered 12/11, 2019 at 16:46 Comment(0)

float represents numbers in an exponential format, so int->FP conversion gives us the position of the highest set bit encoded in the exponent field.

We want int->float conversion with the magnitude rounded down (truncated toward 0), not the default round-to-nearest, which could round up and make 0x3FFFFFFF look like 0x40000000. If you're doing a lot of these conversions without any other FP math in between, you could set the rounding mode in MXCSR (see footnote 1) to truncation, then set it back when you're done.

Otherwise you can use v & ~(v>>8) to keep the 8 most significant bits and zero some or all lower bits, including a possibly-set bit 8 positions below the MSB. That's enough to ensure no rounding mode ever rounds up to the next power of two. It always keeps the 8 MSB because v>>8 shifts in 8 zeros, which become 8 ones after inversion; wherever the MSB is, those zeros are shifted down past it from above, so the mask can never clear the most significant set bit of any integer. Depending on how the set bits below the MSB line up, it might or might not clear more bits below the 8 most significant.
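A scalar illustration of why the mask is enough (my own snippet, not from the answer; valid for 1 <= v <= 0x7FFFFFFF — the full vector version handles 0 and MSB-set inputs with the saturating subtract and min):

```c
#include <stdint.h>
#include <string.h>

// For a positive v, clz(v) == 158 - (exponent field of (float)v), provided
// the conversion doesn't round the mantissa up into the next power of two.
static uint32_t lzcnt_via_float_scalar(uint32_t v)
{
    uint32_t masked = v & ~(v >> 8);  // keep the 8 MSB; carries can't reach the exponent
    float f = (float)(int32_t)masked; // round-to-nearest is now harmless
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);   // type-pun without UB
    return 158 - (bits >> 23);
}
```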

After conversion, we use an integer shift on the bit-pattern to bring the exponent (and sign bit) to the bottom and undo the bias with a saturating subtract. We use min to set the result to 32 if no bits were set in the original 32-bit input.

__m256i avx2_lzcnt_epi32 (__m256i v) {
    // prevent value from being rounded up to the next power of two
    v = _mm256_andnot_si256(_mm256_srli_epi32(v, 8), v); // keep 8 MSB

    v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // per-element int -> float conversion
    v = _mm256_srli_epi32(v, 23); // shift the exponent (and sign bit) to the bottom
    v = _mm256_subs_epu16(_mm256_set1_epi32(158), v); // undo bias: 158 = 127 + 31; saturates to 0 for negative inputs
    v = _mm256_min_epi16(v, _mm256_set1_epi32(32)); // clamp at 32 for an input of 0

    return v;
}

Footnote 1: fp->int conversion is available with truncation (cvtt), but int->fp conversion is only available with default rounding (subject to MXCSR).

AVX512F introduces rounding-mode overrides for 512-bit vectors, which would solve the problem: __m512 _mm512_cvt_roundepi32_ps(__m512i a, int r). But all CPUs with AVX512F also support AVX512CD, so you could just use _mm512_lzcnt_epi32. And with AVX512VL, _mm256_lzcnt_epi32 is available directly.

Unsavory answered 12/11, 2019 at 22:24 Comment(5)
This answer could use more explanation about why it works, specifically what the "round down" is doing and why. I assume the v & ~(v>>8) is because above 2^24 not all integers can be exactly represented, and we need to avoid rounding up during conversion, but why does that work for both large and small integers? I guess if v was small then there are more high zeros, so we're always keeping the 8 most-significant bits? And a float can always exactly represent a number with 8 significant bits. – Ghat
Conceptually, clz(uint32_t a) = (a) ? (158 - (float_as_uint32(uint32_to_float_rz(a)) >> 23)) : 32. If one uses regular conversion of uint32_t to float with round-to-nearest, the result may round up to the next power of two, giving an incorrect clz count. Therefore the need to round towards zero ("rz") in the conversion. Not sure how to best perform this conversion in AVX, though. – Ghislainegholston
@njuffa: without AVX512 for unsigned -> float conversions, probably just use signed conversion and handle the MSB being set using a vblendvps to special-case it. Or arithmetic right shift + ANDNOT to zero elements where the input had the MSB set, with only the ANDNOT on the critical path, not both blend uops. As far as rounding, this answer's trick of clearing all but the 8 MSB should work well. Or if you had a lot of this to do, you could change the MXCSR rounding mode to truncation (toward zero). I made an edit to this answer to add that explanation. – Ghat
The expression v & ~(v>>8) to prevent a possible exponent increase due to rounding up isn't conceptually clean. Instead of 8, one could use any value in the range [1..24] with the same success. The expression's core function isn't to keep the 8 MSB, but to clear bit 8 counting from the MSB (the MSB itself being bit 0). This stops carry propagation in the mantissa at bit 8, so no exponent change is possible. Obviously, we can put such a 'hole' in any position within the range [1..24]. I think it's conceptually cleaner to use v & ~(v>>1). What I'm really interested in: does v>>1 consume less power than v>>2? ;) – Shipload
And 158 looks better written as 127 + 31 (31 reflects that we count leading zeros, while 127 is the pure bias). – Shipload

@aqrit's answer looks like a more-clever use of FP bithacks. My answer below is based on the first place I looked for a bithack which was old and aimed at scalar so it didn't try to avoid double (which is wider than int32 and thus a problem for SIMD).

It uses HW signed int->float conversion and saturating integer subtracts to handle the MSB being set (negative float), instead of stuffing bits into a mantissa for manual uint->double. If you can set MXCSR to round down across a lot of these _mm256_lzcnt_epi32, that's even more efficient.


https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogIEEE64Float suggests stuffing integers into the mantissa of a large double, then subtracting to get the FPU hardware to produce a normalized double. (I think this bit of magic is doing uint32_t -> double, with the technique @Mysticial explains in How to efficiently perform double/int64 conversions with SSE/AVX?, which works for uint64_t up to 2^52 - 1.)

Then grab the exponent bits of the double and undo the bias.

Integer log2 is essentially the same thing as lzcnt: for nonzero v, lzcnt(v) = 31 - ilog2(v), even at powers of 2; only an input of 0 needs special-casing.
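A quick scalar sanity check of that relationship (helper names are mine):

```c
#include <stdint.h>

// floor(log2(v)) by locating the MSB; caller must ensure v != 0.
static int ilog2_u32(uint32_t v)
{
    int pos = -1;
    while (v) { v >>= 1; pos++; }
    return pos;
}

// Leading-zero count expressed in terms of integer log2 (v != 0).
static int lzcnt_from_ilog2(uint32_t v)
{
    return 31 - ilog2_u32(v);
}
```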

The Stanford Graphics bithacks page lists other branchless bithacks you could use that would probably still be better than 8x scalar lzcnt.

If you knew your numbers were always small-ish (like less than 2^23) you could maybe do this with float and avoid splitting and blending.

  int v; // 32-bit integer to find the log base 2 of
  int r; // result of log_2(v) goes here
  union { unsigned int u[2]; double d; } t; // temp

  t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] = 0x43300000;
  t.u[__FLOAT_WORD_ORDER!=LITTLE_ENDIAN] = v;
  t.d -= 4503599627370496.0;
  r = (t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] >> 20) - 0x3FF;

The code above loads a 64-bit (IEEE-754 floating-point) double with a 32-bit integer (with no padding bits) by storing the integer in the mantissa while the exponent is set to 2^52. From this newly minted double, 2^52 (expressed as a double) is subtracted, which sets the resulting exponent to the log base 2 of the input value, v. All that is left is shifting the exponent bits into position (20 bits right) and subtracting the bias, 0x3FF (which is 1023 decimal).

To do this with AVX2, blend and shift+blend odd/even halves with set1_epi32(0x43300000) and _mm256_castps_pd to get a __m256d. And after subtracting, _mm256_castpd_si256 and shift / blend the low/high halves into place then mask to get the exponents.

Doing integer operations on FP bit-patterns is very efficient with AVX2, just 1 cycle of extra latency for a bypass delay when doing integer shifts on the output of an FP math instruction.

(TODO: write it with C++ intrinsics; edits welcome, or someone else could just post it as an answer.)


I'm not sure if you can do anything with int -> double conversion and then reading the exponent field. Negative numbers have no leading zeros and positive numbers give an exponent that depends on the magnitude.

If you did want that, you'd go one 128-bit lane at a time, shuffling to feed xmm -> ymm packed int32_t -> packed double conversion.

Ghat answered 12/11, 2019 at 17:8 Comment(0)

The question is also tagged AVX, but AVX provides no 256-bit integer instructions, which means one needs to fall back to SSE on platforms that support AVX but not AVX2. I am showing an exhaustively tested, if somewhat pedestrian, version below. The basic idea is the same as in the other answers: the count of leading zeros is determined by the floating-point normalization that occurs during integer to floating-point conversion. The exponent of the result has a one-to-one correspondence with the count of leading zeros, except that the result is wrong for an argument of zero. Conceptually:

clz (a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

where float_as_uint32() is a re-interpreting cast and uint32_to_float_rz() is a conversion from unsigned integer to floating-point with truncation. A normal, rounding, conversion could bump up the conversion result to the next power of two, resulting in an incorrect count of leading zero bits.

SSE does not provide truncating integer to floating-point conversion as a single instruction, nor conversions from unsigned integers. This functionality needs to be emulated. The emulation does not need to be exact, as long as it does not change the magnitude of the conversion result. The truncation part is handled by the invert - right shift - andn technique from aqrit's answer. To use signed conversion, we cut the number in half before the conversion, then double and increment after the conversion:

float approximate_uint32_to_float_rz (uint32_t a)
{
    float r = (float)(int)((a >> 1) & ~(a >> 2));
    return r + r + 1.0f;
}

This approach is translated into SSE intrinsics in sse_clz() below.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include "immintrin.h"

/* compute count of leading zero bits using floating-point normalization.

   clz(a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

   The problematic part here is uint32_to_float_rz(). SSE does not offer
   conversion of unsigned integers, and no rounding modes in integer to
   floating-point conversion. Since all we need is an approximate version
   that preserves order of magnitude:

   float approximate_uint32_to_float_rz (uint32_t a)
   {
      float r = (float)(int)((a >> 1) & ~(a >> 2));
      return r + r + 1.0f;
   }
*/  
__m128i sse_clz (__m128i a) 
{
    __m128 fp1 = _mm_set_ps1 (1.0f);
    __m128i zero = _mm_set1_epi32 (0);
    __m128i i158 = _mm_set1_epi32 (158);
    __m128i iszero = _mm_cmpeq_epi32 (a, zero);
    __m128i lsr1 = _mm_srli_epi32 (a, 1);
    __m128i lsr2 = _mm_srli_epi32 (a, 2);
    __m128i atrunc = _mm_andnot_si128 (lsr2, lsr1);
    __m128 atruncf = _mm_cvtepi32_ps (atrunc);
    __m128 atruncf2 = _mm_add_ps (atruncf, atruncf);
    __m128 conv = _mm_add_ps (atruncf2, fp1);
    __m128i convi = _mm_castps_si128 (conv);
    __m128i lsr23 = _mm_srli_epi32 (convi, 23);
    __m128i res = _mm_sub_epi32 (i158, lsr23);
    return _mm_sub_epi32 (res, iszero);
}

/* Portable reference implementation of 32-bit count of leading zeros */    
int clz32 (uint32_t a)
{
    uint32_t r = 32;
    if (a >= 0x00010000) { a >>= 16; r -= 16; }
    if (a >= 0x00000100) { a >>=  8; r -=  8; }
    if (a >= 0x00000010) { a >>=  4; r -=  4; }
    if (a >= 0x00000004) { a >>=  2; r -=  2; }
    r -= a - (a & (a >> 1));
    return r;
}

/* Test floating-point based count leading zeros exhaustively */
int main (void)
{
    __m128i res;
    uint32_t resi[4], refi[4];
    uint32_t count = 0;
    do {
        refi[0] = clz32 (count);
        refi[1] = clz32 (count + 1);
        refi[2] = clz32 (count + 2);
        refi[3] = clz32 (count + 3);
        res = sse_clz (_mm_set_epi32 (count + 3, count + 2, count + 1, count));
        memcpy (resi, &res, sizeof resi);
        if ((resi[0] != refi[0]) || (resi[1] != refi[1]) ||
            (resi[2] != refi[2]) || (resi[3] != refi[3])) {
            printf ("error @ %08x %08x %08x %08x\n",
                    count, count+1, count+2, count+3);
            return EXIT_FAILURE;
        }
        count += 4;
    } while (count);
    return EXIT_SUCCESS;
}
Ghislainegholston answered 13/11, 2019 at 2:25 Comment(3)
SSE does not provide truncating integer to floating-point conversion - not strictly true; you can change MXCSR with _mm_setcsr(unsigned int i). @aqrit's answer indicates that some compilers may mess this up, perhaps reordering it with SIMD intrinsics? Anyway, it's not the most efficient way either, if you only care about the exponent of the result. – Ghat
Do you really need _mm_add_ps? Can you +0.5 instead of +1.0 and fold a +1 into the exponent bias correction? Also, as I commented on another answer, it might work better to do signed int->float conversion and then correct for negative inputs after the fact with a blend, or result & ~(input>>31) using an arithmetic shift, because negative integers have no leading zeros. (Or like aqrit is doing, with a saturating subtract.) Anyway, upvoted for the proof of concept of this strategy even though it's almost certainly slower than @aqrit's (a 128-bit version of aqrit's needs only SSE2 psubusw/pminsw). – Ghat
I think aqrit's answer is farther along the same path and better than anything I was suggesting; handling negative a by letting the int->FP conversion produce a non-zero sign bit. Handling that with a saturating subtract and min to clamp is about as good as a 2-uop blend, and I think my idea still needed more uops. Although if not, it could maybe shorten the critical path even if it didn't save uops. – Ghat
