What's the fastest way to perform an arbitrary 128/256/512 bit permutation using SIMD instructions?

I want to perform an arbitrary permutation of single bits, pairs of bits, and nibbles (4 bits) on a CPU register (xmm, ymm or zmm) of width 128, 256 or 512 bits; this should be as fast as possible. For this I was looking into SIMD instructions. Does anyone know of a way to do this/a library that implements it? I'm using MSVC on Windows and GCC on Linux, and the host language is C or C++. Thanks!

I'm given an arbitrary permutation and need to shuffle a large number of bit vectors/pairs of bit vectors/nibbles. I know how to do this for the bits within a 64 bit value, e.g. using a Benes network.

I also know how to shuffle blocks of 8 bits and larger around on the wider SIMD registers, e.g. using Agner Fog's GPLed VectorClass library (https://www.agner.org/optimize/vectorclass.pdf), which provides a template metaprogramming function that builds shuffles out of AVX2 in-lane byte shuffles and/or larger-element lane-crossing shuffles, given the shuffle as a template parameter.


A more granular subdivision for permutations - into 1, 2 or 4 bit blocks - seems to be hard to achieve across wide vectors, though.
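For readers unfamiliar with the Benes network mentioned above: its building block is the "delta swap", which exchanges selected bits with the bits a fixed distance away. Below is a minimal sketch of one stage (the function name is hypothetical). A full 64-bit Benes network typically chains 11 such stages, with shifts such as 1, 2, 4, 8, 16, 32, 16, 8, 4, 2, 1 (or the mirrored order) and masks computed offline from the permutation.

```c
#include <stdint.h>

/* Hypothetical helper: one delta-swap stage of a Benes network.
   For every bit i set in mask, bits i and i + shift are exchanged.
   Precondition: (mask & (mask << shift)) == 0. */
static inline uint64_t delta_swap(uint64_t x, uint64_t mask, unsigned shift)
{
    uint64_t t = ((x >> shift) ^ x) & mask;  /* 1 where the two bits of a pair differ */
    return x ^ t ^ (t << shift);             /* flip both bits of each differing pair */
}
```

For example, `delta_swap(x, 0x00000000FFFFFFFFull, 32)` swaps the two 32-bit halves of `x`.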

I'm able to do pre-processing on the permutation, e.g. to extract bit masks, calculate indices as necessary e.g. for a Benes network, or whatever else - happy to do that in another high level language as well, so assume that the permutation is given in whatever format is most convenient to solve the problem; small-ish lookup tables included.

I would expect the code to be significantly faster than doing something like

// actually 1 bit per element, not byte.  I want a 256-bit bit-shuffle
const uint8_t in[256] = get_some_vector(); // not a compile-time constant
const uint8_t perm[256] = ...;             // compile-time constant
uint8_t out[256];
for (size_t i = 0; i < 256; i ++)
    out[i] = in[perm[i]];
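Since the snippet above stores one bit per byte, a reference that works on packed bits may be more useful for checking any SIMD version. Below is a minimal scalar sketch. It assumes bit k of the 256-bit vector lives in bit k % 64 of 64-bit word k / 64, and that output bit i is input bit perm[i]. The function name bitperm256_ref is hypothetical.

```c
#include <stdint.h>

/* Hypothetical scalar reference: 256-bit bit permutation on packed bits.
   Output bit i = input bit perm[i]; bit k lives in bit (k % 64) of word k / 64. */
static void bitperm256_ref(const uint64_t in[4], const uint8_t perm[256],
                           uint64_t out[4])
{
    uint64_t res[4] = {0, 0, 0, 0};
    for (int i = 0; i < 256; i++) {
        unsigned src = perm[i];
        uint64_t bit = (in[src >> 6] >> (src & 63)) & 1u;  /* fetch source bit   */
        res[i >> 6] |= bit << (i & 63);                    /* place it at bit i  */
    }
    for (int w = 0; w < 4; w++)
        out[w] = res[w];  /* copying at the end allows out == in */
}
```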

As I said, I have a solution for <= 64 bits (which would be 64 bits, 32 bit-pairs, and 16 nibbles). The problem is also solved for blocks of size 8, 16, 32 etc. on wider SIMD registers.

EDIT: to clarify, the permutation is a compile-time constant (but not just one particular one, I'll compile the program once per permutation given).

Grasmere answered 28/1, 2019 at 19:9 Comment(5)
Why do you need a Benes network for large elements? AVX2 has _mm256_permutevar8x32_epi32 (vpermd), which takes a control vector for an arbitrary lane-crossing permute. (For 64-bit indices, you need to preprocess that to create an input for vpermd, but that's just mapping each i to 2*i + 0..1, with maybe an add same,same to double, then an in-lane shuffle, then adding a vector of 0,1,0,1,...)Indemnity
Or, since you mention 512-bit registers: AVX512F provides vpermq with a vector control, not just an immediate compile-time constant. Are your permutes going to be compile-time constants? The code you show uses const in[] and perm[], so that's no help: the compiler can calculate the contents of out at compile time if both inputs are compile-time constants. And is that supposed to be 1 bit per uint8_t for a 256-bit register? Or do you really want a 256-byte shuffle (8 YMM registers wide)?Indemnity
Thanks! in[] is not a compile-time constant, but perm[] is. I'd be happy to give it in whatever form is suitable (template parameters for a library, masks for assembly instructions, immediates). vpermd shuffles blocks of 32-bit integers across lanes, but not the bits within the integers. As mentioned (maybe not clearly enough), permutations of blocks of size >= 8 are taken care of. So to emphasize: I really want a 256-BIT shuffle, not a BYTE shuffle, unless you tell me that storing one bit/two bits/one nibble per byte and disregarding the rest is faster and the way to go.Grasmere
The trivial (but impractical) solution would be to hand-code every possible permutation and select the right one via template specialization. I doubt that there is an easy solution which is optimal for every case (e.g., I'm sure for some permutations the most efficient way is a combination of multiplications and bit-operations -- and to determine that at compile-time sounds quite hard).Francyne
There's indeed a tool which finds an optimal bit shuffle network on programming.sirrida.de/bit_perm.html by trying them all out. However, the code works only for 64 bits; I'd also prefer a one-size-fits-all solution, and I'm willing to sacrifice a few cycles for it. But valid point.Grasmere
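The preprocessing mentioned in the first comment (turning a permutation of 64-bit elements into a vpermd control vector) amounts to the following. This is a plain-C sketch with a hypothetical function name. It produces the eight 32-bit indices that would then be loaded into a __m256i.

```c
#include <stdint.h>

/* Hypothetical helper: convert a permutation of four 64-bit elements
   (out qword i = in qword q[i]) into the eight 32-bit indices expected by
   _mm256_permutevar8x32_epi32 (vpermd). Qword index q maps to dwords 2q, 2q + 1. */
static void qword_perm_to_vpermd(const uint32_t q[4], uint32_t d[8])
{
    for (int i = 0; i < 4; i++) {
        d[2 * i]     = 2 * q[i];      /* low 32 bits of the source qword  */
        d[2 * i + 1] = 2 * q[i] + 1;  /* high 32 bits of the source qword */
    }
}
```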

The AVX2 256 bit permutation case

I do not think it is possible to write an efficient generic SSE4/AVX2/AVX-512 algorithm that works for all vector sizes (128, 256, 512 bits), and element granularities (bits, bit pairs, nibbles, bytes). One problem is that many AVX2 instructions that exist for, for example, byte size elements, do not exist for double word elements, and vice versa.

Below, the AVX2 256-bit permutation case is discussed. It might be possible to recycle the ideas of this case for other cases.

The idea is to extract 32 (permuted) bits per step from the input vector x. In each step, 32 bytes are read from the permutation vector pos. Bits 7..3 of each pos byte determine which byte of x is needed; that byte is selected with an emulated 256-bit-wide AVX2 lane-crossing byte shuffle written by Ermlg (linked in the code comment below). Bits 2..0 of each pos byte determine which bit within that byte is wanted. With _mm256_movemask_epi8 the 32 bits are collected in one uint32_t. This step is repeated 8 times to get all 256 permuted bits.

The code does not look very elegant. Nevertheless, I would be surprised if a significantly faster, say two times faster, AVX2 method would exist.
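The per-step logic described above can be modelled in portable C, one byte lane at a time. This may help when checking the intrinsics against expected output. Below is a sketch with a hypothetical name, get_32_bits_scalar. The comments name the vector operation in get_32_bits that each line stands for.

```c
#include <stdint.h>

/* Hypothetical scalar model of one get_32_bits() step.
   x: the 256-bit input as 32 bytes (byte 0 = bits 0..7).
   pos: 32 source-bit indices; bit k of the result is input bit pos[k]. */
static uint32_t get_32_bits_scalar(const uint8_t x[32], const uint8_t pos[32])
{
    uint32_t result = 0;
    for (int k = 0; k < 32; k++) {
        uint8_t byte_pos = pos[k] >> 3;                    /* srli + and 0x1F: bits 7..3  */
        uint8_t bit_mask = (uint8_t)(1u << (pos[k] & 7));  /* pshufb on pshufb_mask       */
        uint8_t wanted   = x[byte_pos] & bit_mask;         /* shuf_epi8_lc, then and      */
        if (wanted == bit_mask)                            /* cmpeq_epi8                  */
            result |= 1u << k;                             /* movemask_epi8               */
    }
    return result;
}
```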

/*     gcc -O3 -m64 -Wall -mavx2 -march=skylake bitperm_avx2.c     */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

static inline __m256i shuf_epi8_lc(__m256i value, __m256i shuffle); /* static: a plain C99 inline may leave no out-of-line definition at -O0 */
int print_epi64(__m256i  a);

uint32_t get_32_bits(__m256i x, __m256i pos){
    __m256i pshufb_mask  = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_pos     = _mm256_srli_epi32(pos, 3);                       /* which byte within the 32 bytes    */
            byte_pos     = _mm256_and_si256(byte_pos, _mm256_set1_epi8(0x1F)); /* mask off the unwanted bits */
    __m256i bit_pos      = _mm256_and_si256(pos, _mm256_set1_epi8(0x07));   /* which bit within the byte         */
    __m256i bit_pos_mask = _mm256_shuffle_epi8(pshufb_mask, bit_pos);       /* get bit mask                      */
    __m256i bytes_wanted = shuf_epi8_lc(x, byte_pos);                       /* get the right bytes               */
    __m256i bits_wanted  = _mm256_and_si256(bit_pos_mask, bytes_wanted);    /* apply the bit mask to get rid of the unwanted bits within the byte */
    __m256i bits_x8      = _mm256_cmpeq_epi8(bits_wanted, bit_pos_mask);    /* check if the bit is set           */
    return _mm256_movemask_epi8(bits_x8);
}

__m256i get_256_bits(__m256i x, uint8_t* pos){ /* glue the 32 bit results together */
    uint64_t t0 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[0]));
    uint64_t t1 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[32]));
    uint64_t t2 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[64]));
    uint64_t t3 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[96]));
    uint64_t t4 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[128]));
    uint64_t t5 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[160]));
    uint64_t t6 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[192]));
    uint64_t t7 = get_32_bits(x, _mm256_loadu_si256((__m256i*)&pos[224]));
    uint64_t t10 = (t1<<32)|t0;
    uint64_t t32 = (t3<<32)|t2;
    uint64_t t54 = (t5<<32)|t4;
    uint64_t t76 = (t7<<32)|t6;
    return(_mm256_set_epi64x(t76, t54, t32, t10));
}


inline __m256i shuf_epi8_lc(__m256i value, __m256i shuffle){
/* Ermlg's lane crossing byte shuffle https://mcmap.net/q/907584/-shuffle-elements-of-__m256i-vector */
const __m256i K0 = _mm256_setr_epi8(
    0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70,
    0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0);
const __m256i K1 = _mm256_setr_epi8(
    0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
    0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70, 0x70);
return _mm256_or_si256(_mm256_shuffle_epi8(value, _mm256_add_epi8(shuffle, K0)), 
    _mm256_shuffle_epi8(_mm256_permute4x64_epi64(value, 0x4E), _mm256_add_epi8(shuffle, K1)));
}


int main(){
    __m256i    input = _mm256_set_epi16(0x1234,0x9876,0x7890,0xABCD, 0x3456,0x7654,0x0123,0x4567,
                                        0x0123,0x4567,0x89AB,0xCDEF, 0xFEDC,0xBA98,0x7654,0x3210);
/* Example                                                                                         */
/*            240  224  208  192    176  160  144  128    112   96   80   64     48   32   16    0 */                        
/* input     1234 9876 7890 ABCD | 3456 7654 0123 4567 | 0123 4567 89AB CDEF | FEDC BA98 7654 3210 */
/* output    0000 0000 0012 00FF | 90AB 3210 7654 ABCD | 8712 1200 FF90 AB32 | 7654 ABCD 1087 7654 */
    uint8_t permutation[256] = {16,17,18,19,     20,21,22,23,      24,25,26,27,     28,29,30,31,
                                28,29,30,31,     32,33,34,35,      0,1,2,3,         4,5,6,7,
                                72,73,74,75,     76,77,78,79,      80,81,82,83,     84,85,86,87,      
                                160,161,162,163, 164,165,166,167,  168,169,170,171, 172,173,174,175,  
                                8,9,10,11,       12,13,14,15,      200,201,202,203, 204,205,206,207,
                                208,209,210,211, 212,213,214,215,  215,215,215,215, 215,215,215,215,
                                1,1,1,1,         1,1,1,1,          248,249,250,251, 252,253,254,255,
                                248,249,250,251, 252,253,254,255,  28,29,30,31,     32,33,34,35,
                                72,73,74,75,     76,77,78,79,      80,81,82,83,     84,85,86,87,
                                160,161,162,163, 164,165,166,167,  168,169,170,171, 172,173,174,175,
                                0,1,2,3,         4,5,6,7,          8,9,10,11,       12,13,14,15,
                                200,201,202,203, 204,205,206,207,  208,209,210,211, 212,213,214,215,
                                215,215,215,215, 215,215,215,215,  1,1,1,1,         1,1,1,1,
                                248,249,250,251, 252,253,254,255,  1,1,1,1,         1,1,1,1,
                                1,1,1,1,         1,1,1,1,          1,1,1,1,         1,1,1,1,
                                1,1,1,1,         1,1,1,1,          1,1,1,1,         1,1,1,1};
               printf("input = \n");
               print_epi64(input);
    __m256i    x = get_256_bits(input, permutation);
               printf("permuted input = \n");
               print_epi64(x);
               return 0;
}


int print_epi64(__m256i  a){
    uint64_t  v[4];
    int i;
    _mm256_storeu_si256((__m256i*)v,a);
    for (i = 3; i>=0; i--) printf("%016llX  ", (unsigned long long)v[i]); /* %lX is only 32 bits on Windows */
    printf("\n");
    return 0;
}

The output with the example permutation looks correct:

$ ./a.out
input = 
123498767890ABCD  3456765401234567  0123456789ABCDEF  FEDCBA9876543210  
permuted input = 
00000000001200FF  90AB32107654ABCD  87121200FF90AB32  7654ABCD10877654  

Efficiency

If you look carefully at the algorithm, you will see that some operations depend only on the permutation vector pos, and not on x. This means that applying the permutation with a variable x and a fixed pos should be more efficient than applying it with both x and pos variable.

This is illustrated by the following code:

/* apply the same permutation several times */
int perm_array(__m256i* restrict x_in, uint8_t* restrict pos, __m256i* restrict x_out){
    for (int i = 0; i<1024; i++){
            x_out[i]=get_256_bits(x_in[i], pos);
    }
    return 0;
}

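With clang and gcc this compiles to really nice code: loop .L5 (at line 237 of the compiler output) contains only 16 vpshufbs instead of 24, because the 8 vpshufbs that build bit_pos_mask depend only on pos and are hoisted out of the loop. The vpaddbs are hoisted out of the loop as well. Note also that there is only one vpermq inside the loop, since the lane swap of x is shared by all 8 steps.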

I do not know if MSVC will hoist so many instructions out of the loop. If not, it might be possible to improve the performance of the loop by modifying the code manually, such that the operations which depend only on pos, and not on x, are hoisted out of the loop.

With respect to performance on Intel Skylake: the throughput of this loop is likely limited by the roughly 32 port 5 micro-ops per loop iteration. This means that the throughput in a loop context such as perm_array is about 256 permuted bits per 32 CPU cycles, or about 8 permuted bits per CPU cycle.
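For compilers that do not hoist the pos-only work automatically, the split can be made explicit. Compute everything that depends only on pos once, then do only the x-dependent work per input. Below is a scalar sketch of that split; perm_plan, plan_init and plan_apply are hypothetical names. In the AVX2 version, the plan would hold the 8 bit_pos_mask vectors and, per step, the two pre-added shuffle controls (shuffle + K0, shuffle + K1).

```c
#include <stdint.h>

/* Hypothetical "plan": everything that depends only on pos. */
typedef struct {
    uint8_t byte_pos[256];  /* which input byte each output bit comes from */
    uint8_t bit_mask[256];  /* which bit within that byte                  */
} perm_plan;

/* Done once per permutation (the hoisted part). */
static void plan_init(perm_plan *p, const uint8_t pos[256])
{
    for (int i = 0; i < 256; i++) {
        p->byte_pos[i] = pos[i] >> 3;
        p->bit_mask[i] = (uint8_t)(1u << (pos[i] & 7));
    }
}

/* Done per input vector: only x-dependent work remains.
   Output bit i = input bit pos[i]; out must not alias x. */
static void plan_apply(const perm_plan *p, const uint8_t x[32], uint8_t out[32])
{
    for (int j = 0; j < 32; j++) {
        uint8_t b = 0;
        for (int k = 0; k < 8; k++) {
            int i = 8 * j + k;
            if (x[p->byte_pos[i]] & p->bit_mask[i])
                b |= (uint8_t)(1u << k);
        }
        out[j] = b;
    }
}
```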


128 bit permutations using AVX2 instructions

This code is quite similar to the 256 bit permutation case. Although only 128 bits are permuted, the full 256 bit width of the AVX2 registers is used to achieve the best performance. Here the byte shuffles are not emulated. This is because there exists an efficient single instruction to do the byte shuffling within the 128 bit lanes: vpshufb.
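To see why the 256-bit case needs emulation while this one does not, here is a scalar model of vpshufb's per-lane semantics. It also includes Ermlg's construction from shuf_epi8_lc, which combines two in-lane shuffles (one on the lane-swapped input) into one lane-crossing shuffle. This is a sketch for illustration only; the function names are hypothetical.

```c
#include <stdint.h>

/* Hypothetical scalar model of _mm256_shuffle_epi8 (vpshufb ymm): each 16-byte
   lane is shuffled independently; an index byte with bit 7 set yields 0. */
static void pshufb256_model(const uint8_t v[32], const uint8_t idx[32], uint8_t r[32])
{
    for (int i = 0; i < 32; i++) {
        int lane = i & 16;  /* 0 for the low lane, 16 for the high lane */
        r[i] = (idx[i] & 0x80) ? 0 : v[lane + (idx[i] & 0x0F)];
    }
}

/* Hypothetical scalar model of shuf_epi8_lc (indices must be 0..31).
   Adding K0 (0x70 low lane / 0xF0 high lane) or K1 (swapped) sets bit 7
   exactly for the indices that the respective shuffle must not serve. */
static void shuf_epi8_lc_model(const uint8_t v[32], const uint8_t idx[32], uint8_t r[32])
{
    uint8_t swapped[32], i0[32], i1[32], a[32], b[32];
    for (int i = 0; i < 32; i++) {
        swapped[i] = v[i ^ 16];                              /* vpermq 0x4E: swap lanes */
        i0[i] = (uint8_t)(idx[i] + (i < 16 ? 0x70 : 0xF0));  /* shuffle + K0 */
        i1[i] = (uint8_t)(idx[i] + (i < 16 ? 0xF0 : 0x70));  /* shuffle + K1 */
    }
    pshufb256_model(v, i0, a);        /* bytes from the same lane  */
    pshufb256_model(swapped, i1, b);  /* bytes from the other lane */
    for (int i = 0; i < 32; i++)
        r[i] = a[i] | b[i];           /* exactly one of a, b is nonzero-capable */
}
```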

Function perm_array_128 tests the performance of the bit permutation for a fixed permutation and a variable input x. The assembly loop contains about 11 port 5 (p5) micro-ops, if we assume an Intel Skylake CPU. These 11 p5 micro-ops take at least 11 CPU cycles (throughput). So, in the best case we get a throughput of about 12 permuted bits per cycle, which is about 1.5 times as fast as the 256 bit permutation case.

/*     gcc -O3 -m64 -Wall -mavx2 -march=skylake bitperm128_avx2.c     */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

int print128_epi64(__m128i  a);

uint32_t get_32_128_bits(__m256i x, __m256i pos){                           /* extract 32 permuted bits out from 2x128 bits   */
    __m256i pshufb_mask  = _mm256_set_epi8(0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1, 0,0,0,0, 0,0,0,0, 128,64,32,16, 8,4,2,1);
    __m256i byte_pos     = _mm256_srli_epi32(pos, 3);                       /* which byte do we need within the 16 byte lanes. bits 6,5,4,3 select the right byte */
            byte_pos     = _mm256_and_si256(byte_pos, _mm256_set1_epi8(0xF)); /* mask off the unwanted bits (unnecessary if _mm256_srli_epi8 existed) */
    __m256i bit_pos      = _mm256_and_si256(pos, _mm256_set1_epi8(0x07));   /* which bit within the byte                 */
    __m256i bit_pos_mask = _mm256_shuffle_epi8(pshufb_mask, bit_pos);       /* get bit mask                              */
    __m256i bytes_wanted = _mm256_shuffle_epi8(x, byte_pos);                /* get the right bytes                       */
    __m256i bits_wanted  = _mm256_and_si256(bit_pos_mask, bytes_wanted);    /* apply the bit mask to get rid of the unwanted bits within the byte */
    __m256i bits_x8      = _mm256_cmpeq_epi8(bits_wanted, bit_pos_mask);    /* set all bits if the wanted bit is set     */        
    return _mm256_movemask_epi8(bits_x8);                                   /* move most significant bit of each byte to 32 bit register */
}


__m128i permute_128_bits(__m128i x, uint8_t* pos){      /* get bit permutations in 32 bit pieces and glue them together */
    __m256i  x2 = _mm256_broadcastsi128_si256(x);   /* broadcast x to the hi and lo lane                            */
    uint64_t t0 = get_32_128_bits(x2, _mm256_loadu_si256((__m256i*)&pos[0]));
    uint64_t t1 = get_32_128_bits(x2, _mm256_loadu_si256((__m256i*)&pos[32]));
    uint64_t t2 = get_32_128_bits(x2, _mm256_loadu_si256((__m256i*)&pos[64]));
    uint64_t t3 = get_32_128_bits(x2, _mm256_loadu_si256((__m256i*)&pos[96]));
    uint64_t t10 = (t1<<32)|t0;
    uint64_t t32 = (t3<<32)|t2;
    return(_mm_set_epi64x(t32, t10));
}

/* Test loop performance with the following loop (see assembly) -> 11 port5 uops inside the critical loop */
/* Use gcc -O3 -m64 -Wall -mavx2 -march=skylake -S bitperm128_avx2.c to generate the assembly             */
int perm_array_128(__m128i* restrict x_in, uint8_t* restrict pos, __m128i* restrict x_out){
    for (int i = 0; i<1024; i++){
            x_out[i]=permute_128_bits(x_in[i], pos);
    }
    return 0;
}


int main(){
    __m128i    input = _mm_set_epi16(0x0123,0x4567,0xFEDC,0xBA98,  0x7654,0x3210,0x89AB,0xCDEF);
/* Example                                                                                         */
/*             112   96   80   64     48   32   16    0 */                        
/* input      0123 4567 FEDC BA98   7654 3210 89AB CDEF */
/* output     8FFF CDEF DCBA 08EF   CDFF DCBA EFF0 89AB */
    uint8_t permutation[128] = {16,17,18,19,     20,21,22,23,      24,25,26,27,     28,29,30,31,
                                32,32,32,32,     36,36,36,36,      0,1,2,3,         4,5,6,7,
                                72,73,74,75,     76,77,78,79,      80,81,82,83,     84,85,86,87,      
                                0,0,0,0,         0,0,0,0,          8,9,10,11,       12,13,14,15,      
                                0,1,2,3,         4,5,6,7,          28,29,30,31,     32,33,34,35,
                                72,73,74,75,     76,77,78,79,      80,81,82,83,     84,85,86,87,
                                0,1,2,3,         4,5,6,7,          8,9,10,11,       12,13,14,15,
                                1,1,1,1,         1,1,1,1,          1,1,1,1,         32,32,32,1};
               printf("input = \n");
               print128_epi64(input);
    __m128i    x = permute_128_bits(input, permutation);
               printf("permuted input = \n");
               print128_epi64(x);
               return 0;
}


int print128_epi64(__m128i  a){
  uint64_t  v[2];
  int i;
  _mm_storeu_si128((__m128i*)v,a);
  for (i = 1; i>=0; i--) printf("%016llX  ", (unsigned long long)v[i]); /* %lX is only 32 bits on Windows */
  printf("\n");
  return 0;
}

Example output for some arbitrary permutation:

$ ./a.out
input = 
01234567FEDCBA98  7654321089ABCDEF  
permuted input = 
8FFFCDEFDCBA08EF  CDFFDCBAEFF089AB  
Natal answered 29/1, 2019 at 12:20 Comment(7)
The same idea can be used to permute nibbles from a 256 bit wide AVX2 register. There are 64 nibbles, so we have 2 steps. In each step 32 bytes are permuted (lane crossing) based on bits 5..1 of the permutation index. Such a byte contains a high and a low nibble. Bit 0 of the permutation index is used as selector for a vpblendvb between this high and low nibble. This gives 32 nibbles. In the next step another 32 nibbles are selected from the 256 bit register. Some simple bit/byte manipulations lead eventually to the result.Natal
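Below is a scalar sketch of the nibble idea from the comment above. It assumes nibble n is the low nibble of byte n / 2 when n is even and the high nibble when n is odd; permute_64_nibbles is a hypothetical name. The byte selection stands in for the lane-crossing byte shuffle, the nibble choice for the vpblendvb, and the final packing for the closing bit/byte manipulations.

```c
#include <stdint.h>

/* Hypothetical scalar sketch: permute the 64 nibbles of a 256-bit vector.
   Output nibble i = input nibble idx[i] (idx values 0..63). */
static void permute_64_nibbles(const uint8_t x[32], const uint8_t idx[64], uint8_t out[32])
{
    uint8_t res[32] = {0};
    for (int i = 0; i < 64; i++) {
        uint8_t byte = x[(idx[i] >> 1) & 31];              /* bits 5..1: which byte   */
        uint8_t nib  = (idx[i] & 1) ? (uint8_t)(byte >> 4)  /* bit 0: high nibble ...  */
                                    : (uint8_t)(byte & 0x0F); /* ... or low nibble     */
        res[i >> 1] |= (uint8_t)(nib << (4 * (i & 1)));    /* pack two nibbles a byte */
    }
    for (int j = 0; j < 32; j++)
        out[j] = res[j];
}
```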
This is amazing! Thank you so much, that's already extremely helpful, and well-explained! I can also see how one would then do a two-bit shuffle efficiently: just store the even and odd bits in an alternating fashion in two 256 bit wide AVX2 registers and permute twice. Neat. Would you have any pointers as to how one would go about 128 bits?Grasmere
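The even/odd split for bit pairs is sketched below on a 64-bit word for brevity. The same split applies to 256 bits, where each half becomes a 128-bit bit vector. Output pair i is input pair p[i]; the function names are hypothetical.

```c
#include <stdint.h>

/* Hypothetical helper: 32-bit bit permutation, out bit i = in bit p[i]. */
static uint32_t permute_bits32(uint32_t x, const uint8_t p[32])
{
    uint32_t r = 0;
    for (int i = 0; i < 32; i++)
        r |= ((x >> (p[i] & 31)) & 1u) << i;
    return r;
}

/* Hypothetical sketch: permute the 32 bit pairs of a 64-bit word by
   de-interleaving even and odd bits, permuting both halves with the same
   bit permutation, and interleaving them again. */
static uint64_t permute_bit_pairs64(uint64_t x, const uint8_t p[32])
{
    uint32_t even = 0, odd = 0;  /* bit 0 / bit 1 of every pair */
    for (int i = 0; i < 32; i++) {
        even |= (uint32_t)((x >> (2 * i)) & 1u) << i;
        odd  |= (uint32_t)((x >> (2 * i + 1)) & 1u) << i;
    }
    even = permute_bits32(even, p);  /* the same permutation ... */
    odd  = permute_bits32(odd, p);   /* ... applied twice        */

    uint64_t r = 0;
    for (int i = 0; i < 32; i++) {
        r |= (uint64_t)((even >> i) & 1u) << (2 * i);
        r |= (uint64_t)((odd >> i) & 1u) << (2 * i + 1);
    }
    return r;
}
```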
Great it works! With respect to 128 bits: are you interested in a 128 bit permutation using SSEx or using AVX2? Although the output is only 128 bits, the computations may benefit (better throughput) from using 256 bits wide AVX2 registers and instructions. An advantage of the 128 bit case is that there is no need to emulate the 256 bit wide byte shuffle, because the 128 bit wide byte shuffle is only one instruction: pshufb.Natal
AVX2 is perfectly fine! Once I have the bit that uses this implemented I'll post some real-world benchmarks.Grasmere
@JBausch 128 bit case added.Natal
neat, thanks so much! I'll accept this answer as you're also explaining things really well, much appreciated; if anyone knows the 512 bit case (using AVX512, or maybe a sped-up version for a 256 or 128 bit shuffle using AVX512 - note that e.g. a 64 bit shuffle can be done with one AVX512 vpermb to my knowledge) feel free to go ahead anyways, I feel like this question is a useful resource not only for me. Once I have this incorporated in my use case I'll update the question with some benchmarks.Grasmere
AVX-512 is very useful here indeed. I don't have the hardware yet, but it should be possible to recycle the ideas in this answer for AVX-512. Nice question.Natal

AVX2

The above answer is very good, but we can do a little better if the data is in memory:

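#include <immintrin.h>
#include <stdint.h>
#include <stddef.h>   /* size_t */
#include <string.h>   /* memcpy */

void fill_avx2_perm_table(__m256i table[24], uint8_t idx[256]) {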
    __m256i bit_idx_mask = _mm256_set1_epi8(0x7);
    __m256i byte_idx_mask = _mm256_set1_epi8(0xf);
    __m256i bit_mask_lookup = _mm256_set1_epi64x(0x8040201008040201);
    __m256i mask_out = _mm256_set1_epi8(-1);

    for (int i = 0; i < 8; ++i) {
        __m256i perm_32 = _mm256_loadu_si256((const __m256i*)(idx + 32 * i)); // high bit set -> comes from bits 128 .. 255
        __m256i shuf = _mm256_and_si256(_mm256_srli_epi32(perm_32, 3), byte_idx_mask);

        __m256i shuf_lo = _mm256_blendv_epi8(shuf, mask_out, perm_32);
        __m256i shuf_hi = _mm256_blendv_epi8(mask_out, shuf, perm_32);
        __m256i bit_mask = _mm256_shuffle_epi8(bit_mask_lookup, _mm256_and_si256(bit_idx_mask, perm_32)); // 1 << (idx & 7)

        _mm256_store_si256(table + 2*i, shuf_lo);
        _mm256_store_si256(table + 2*i + 1, shuf_hi);
        _mm256_store_si256(table + 16 + i, bit_mask);
    }
}

void permute_256_array(char* arr, size_t len, uint8_t idx[256]) {
    __m256i perm_table[16 /* shuffles */ + 8 /* bit masks */];

    fill_avx2_perm_table(perm_table, idx);
    __m256i zero = _mm256_setzero_si256();

    char* end = arr + len * 32;
    for (; arr < end; arr += 32) {
        __m256i lo = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*) arr));
        __m256i hi = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*) arr + 1));
        __m256i lo_source, hi_source, bit_mask, bits;

        uint32_t result;

#define DO_ITER(i)  \
            lo_source = _mm256_shuffle_epi8(lo, _mm256_loadu_si256(perm_table + 2 * i)); \
            hi_source = _mm256_shuffle_epi8(hi, _mm256_loadu_si256(perm_table + 2 * i + 1)); \
\
            bit_mask = _mm256_loadu_si256(perm_table + 16 + i); \
            bits = _mm256_and_si256(_mm256_or_si256(lo_source, hi_source), bit_mask); \
            bits = _mm256_cmpeq_epi8(bits, bit_mask); \
\
            result = _mm256_movemask_epi8(bits); \
            memcpy(arr + 4 * i, &result, 4);

        DO_ITER(0) DO_ITER(1) DO_ITER(2) DO_ITER(3) DO_ITER(4) DO_ITER(5) DO_ITER(6) DO_ITER(7)
#undef DO_ITER
    }
}

Unlike vpermq, broadcasting 128 bits from memory to all 256 bits of a ymm register does not use a shuffle µop. Unfortunately, compilers sometimes seem to pessimize by inserting the vpmovmskb results into a vector register before storing them, and this can be seen in the accepted answer's compiler output. One can either insert asm memory clobbers (asm volatile ("" ::: "memory");) between the stores, which unfortunately also prevents the compiler from reordering instructions across them, or (as I did in my application) link an assembly routine.

On Cascade Lake with -march=core-avx2, I get ~12.3 bits / cycle.

AVX512BW / VBMI

We can do a bit better with AVX512. Most of the examples will use the following function, similar to the first few lines of get_32_128_bits in the original AVX2 answer.

void get_permute_constants(__m512i idx, __m512i* byte_idx, __m512i* bit_mask) {
    /* locals rather than file-scope consts: in C, a static initializer must be a constant expression */
    const __m512i bit_idx_mask = _mm512_set1_epi8(0x7);
    const __m512i bit_mask_lookup = _mm512_set1_epi64(0x8040201008040201);

    *byte_idx = _mm512_srli_epi32(_mm512_andnot_si512(bit_idx_mask, idx), 3);  /* idx >> 3 in every byte */
    idx = _mm512_and_si512(idx, bit_idx_mask);                                  /* idx & 7                */
    *bit_mask = _mm512_shuffle_epi8(bit_mask_lookup, idx);                      /* 1 << (idx & 7)         */
}

For loop-invariant permutations, these values should be hoisted rather than repeatedly computed.

64-bit

In the special case of 64 bits and smaller, we can use vpshufbitqmb to select 64 bits into a mask register and store directly to memory with kmovq m64, k. If vpshufbitqmb is not available, we can emulate it with a byte shuffle followed by vptestmb with a mask of all the needed bits. (vptestmb computes the logical AND of two vector registers and, for every nonzero byte, writes a 1 to a mask register.)

#ifdef __AVX512BITALG__
void permute_64(uint64_t* in, uint64_t* out, __m512i idx) { 
    __m512i in_v = _mm512_set1_epi64(*in); 
    __mmask64 permuted_bits = _mm512_bitshuffle_epi64_mask(in_v, idx); 
    _store_mask64((__mmask64*) out, permuted_bits); 
} 
#else
void permute_64(uint64_t* in, uint64_t* out, __m512i idx) {
    __m512i byte_idx, bit_mask;
    get_permute_constants(idx, &byte_idx, &bit_mask);

    __m512i in_v = _mm512_set1_epi64(*in);
    __m512i in_shuffled = _mm512_shuffle_epi8(in_v, byte_idx);
    __mmask64 permuted_bits = _mm512_test_epi8_mask(in_shuffled, bit_mask);

    _store_mask64((__mmask64*) out, permuted_bits);
}
#endif
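A scalar model can show what vpshufbitqmb computes and confirm that the vpshufb + vptestmb fallback gives the same mask for indices 0..63. Both function names below are hypothetical. The semantics follow Intel's description of _mm512_bitshuffle_epi64_mask.

```c
#include <stdint.h>

/* Hypothetical scalar model of _mm512_bitshuffle_epi64_mask(b, c) (vpshufbitqmb):
   for each qword j and each of its 8 control bytes i,
   mask bit 8*j + i = bit (c byte & 63) of data qword j. */
static uint64_t bitshuffle_epi64_mask_model(const uint64_t b[8], const uint8_t c[64])
{
    uint64_t k = 0;
    for (int j = 0; j < 8; j++)
        for (int i = 0; i < 8; i++) {
            unsigned sel = c[8 * j + i] & 63;
            k |= ((b[j] >> sel) & 1u) << (8 * j + i);
        }
    return k;
}

/* Hypothetical scalar model of the fallback: byte shuffle by sel >> 3,
   then test against the one-bit mask 1 << (sel & 7). */
static uint64_t fallback_mask_model(uint64_t in, const uint8_t c[64])
{
    uint64_t k = 0;
    for (int i = 0; i < 64; i++) {
        unsigned sel = c[i] & 63;
        uint8_t byte = (uint8_t)(in >> (8 * (sel >> 3)));  /* vpshufb with byte_idx       */
        uint8_t mask = (uint8_t)(1u << (sel & 7));          /* bit_mask                   */
        if (byte & mask)                                    /* vptestmb: nonzero AND -> 1 */
            k |= 1ull << i;
    }
    return k;
}
```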

The workhorse as part of a loop:

1ed8:       62 f2 fd 48 59 00       vpbroadcastq zmm0,QWORD PTR [rax]
1ede:       48 83 c0 08             add    rax,0x8
1ee2:       62 f2 7d 48 8f c1       vpshufbitqmb k0,zmm0,zmm1
1ee8:       c4 e1 f8 91 40 f8       kmovq  QWORD PTR [rax-0x8],k0
1eee:       4c 39 e8                cmp    rax,r13
1ef1:       75 e5                   jne    1ed8

The observed performance on Ice Lake is ~1.38 cycles / element, or ~46 bits / cycle (with some error because I can't access performance counters). The fallback implementation runs at 2 cycles / element, or 32 bits / cycle, because vptestmb and vpshufb compete for the same port.

128-bit

void permute_128(const char* in, char* out,
    __m512i byte_idx1, __m512i bit_mask1,
    __m512i byte_idx2, __m512i bit_mask2) {
    __m512i in_v = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*) in));

    __mmask64 permuted_1 = _mm512_test_epi8_mask(
        _mm512_shuffle_epi8(in_v, byte_idx1), bit_mask1);
    __mmask64 permuted_2 = _mm512_test_epi8_mask(
        _mm512_shuffle_epi8(in_v, byte_idx2), bit_mask2);

    _store_mask64((__mmask64*) out, permuted_1);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 1, permuted_2);
}

void permute_128_array(char* arr, size_t count, uint8_t idx[128]) {
    __m512i idx1 = _mm512_loadu_si512(idx);
    __m512i idx2 = _mm512_loadu_si512(idx + 64);

    __m512i byte_idx1, bit_mask1, byte_idx2, bit_mask2;
    get_permute_constants(idx1, &byte_idx1, &bit_mask1);
    get_permute_constants(idx2, &byte_idx2, &bit_mask2);

    count *= 16;

    for (size_t i = 0; i < count; i += 16)
        permute_128(arr + i, arr + i, byte_idx1, bit_mask1, byte_idx2, bit_mask2);
}

The loop runs consistently at 4 cycles / element, or 32 bits / cycle, with the same bottleneck as before (vpshufb / vptestmb).

256-bit

128-bit-lane-crossing byte-granularity shuffles (vpermb and friends) were only added in AVX512VBMI (Cannon Lake / Ice Lake and later; Zen 4). The AVX512BW version is not much different from the AVX2 version: the 256-bit input is broadcast as two 128-bit chunks and shuffled using merge-masked vpshufb.

void fill_avx512_perm_table(__m512i table[8], uint8_t idx[256], __mmask64 keeps[4]) {
    for (int i = 0; i < 4; ++i) {
        __m512i perm_64 = _mm512_loadu_si512(idx + 64 * i);
        get_permute_constants(perm_64, table + i, table + i + 4);

        if (keeps) // we'll use this function later w/o needing masks
            keeps[i] = _mm512_movepi8_mask(perm_64);  // hi bit -> 128 .. 255
    }
}

void permute_256(const char* in, char* out, __m512i perm_table[8], __mmask64 keeps[4]) {
    __m512i in_v1 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*) in));
    __m512i in_v2 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*) in + 1));
 
    __m512i ones = _mm512_set1_epi8(1);

#define DO_INTEL_OPT 0

#define DO_ITER(i, mask_reg) __mmask64 mask_reg; { \
    __m512i shuffle = perm_table[i]; \
    __m512i perm = _mm512_shuffle_epi8(in_v1, shuffle); \
    perm = _mm512_mask_shuffle_epi8(perm, keeps[i], in_v2, shuffle); \
    if (DO_INTEL_OPT) { \
        perm = _mm512_andnot_si512(perm, perm_table[i + 4]); \
        perm = _mm512_sub_epi8(perm, ones); \
        mask_reg = _mm512_movepi8_mask(perm); \
    } else mask_reg = _mm512_test_epi8_mask(perm, perm_table[i + 4]); \
}

    DO_ITER(0, permuted_1) DO_ITER(1, permuted_2) DO_ITER(2, permuted_3) DO_ITER(3, permuted_4)

    _store_mask64((__mmask64*) out, permuted_1);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 1, permuted_2);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 2, permuted_3);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 3, permuted_4);
#undef DO_ITER
#undef DO_INTEL_OPT
}

void permute_256_array(char* arr, size_t len, uint8_t idx[256]) {
    __m512i table[8];
    __mmask64 keeps[4];
    fill_avx512_perm_table(table, idx, keeps);

    char* end = arr + 32 * len;
    for (; arr < end; arr += 32)
        permute_256(arr, arr, table, keeps);
}
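The chunk selection in the AVX512BW version can be checked bit by bit in scalar code. vpshufb only looks at the low 4 bits of byte_idx, and the keeps mask (bit 7 of the index) decides whether the byte comes from in_v1 (bytes 0..15) or in_v2 (bytes 16..31). Below is a sketch with a hypothetical function name.

```c
#include <stdint.h>

/* Hypothetical scalar model of one output bit in the AVX512BW 256-bit scheme. */
static int get_bit_bw_model(const uint8_t x[32], uint8_t idx)
{
    uint8_t byte_idx = idx >> 3;               /* from get_permute_constants        */
    int keep = (idx & 0x80) != 0;              /* keeps[i]: source is bits 128..255 */
    const uint8_t *chunk = keep ? x + 16 : x;  /* in_v2 : in_v1                     */
    uint8_t byte = chunk[byte_idx & 0x0F];     /* in-lane vpshufb uses 4 index bits */
    return (byte >> (idx & 7)) & 1;            /* vptestmb against 1 << (idx & 7)   */
}
```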

I get 12 cycles / 256-bit element (21 bits / cycle), which isn't even twice as fast as the AVX2-only version, because comparing into mask registers competes with the shuffles for port 5. The only vector-to-mask operation that doesn't use port 5 appears to be vpmovb2m and the like. There doesn't seem to be a profitable way to move the relevant bit to the most significant bit of each byte using a single instruction (there are no variable byte shifts), but we can do it in two (vpandn, then vpsubb of 1, enabled above by DO_INTEL_OPT), which theoretically improves throughput to 10 cycles / element (25.6 bits / cycle). This may or may not be worth the complexity, and I'm sure it is strictly worse on AMD.
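The two-instruction trick behind DO_INTEL_OPT can be verified exhaustively for a single byte in scalar code. It uses vpandn, then vpsubb of 1, then reads bit 7 with vpmovb2m. Below is a sketch with a hypothetical function name.

```c
#include <stdint.h>

/* Hypothetical scalar check of the DO_INTEL_OPT sequence for one byte.
   v: the shuffled input byte, m: its single-bit mask (1 << (idx & 7)).
   Returns the bit that vpmovb2m would extract; equals ((v & m) != 0). */
static int intel_opt_bit(uint8_t v, uint8_t m)
{
    uint8_t t = (uint8_t)(~v & m);  /* vpandn: 0 if the wanted bit is set, else m  */
    t = (uint8_t)(t - 1);           /* vpsubb: 0 - 1 wraps to 0xFF; m - 1 <= 0x7F */
    return t >> 7;                  /* vpmovb2m: most significant bit             */
}
```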

The VBMI solution is unremarkable and runs at 32 bits / cycle:

void permute_256(const char* in, char* out, __m512i perm_table[8], __mmask64 keeps[4] /* unused */) {
    __m512i in_v = _mm512_broadcast_i64x4(_mm256_loadu_si256((const __m256i*) in));

#define DO_ITER(i, mask_reg) __mmask64 mask_reg; { \
    __m512i shuffle = perm_table[i]; \
    __m512i perm = _mm512_permutexvar_epi8(shuffle, in_v); \
    mask_reg = _mm512_test_epi8_mask(perm, perm_table[i + 4]); \
}

    DO_ITER(0, permuted_1) DO_ITER(1, permuted_2) DO_ITER(2, permuted_3) DO_ITER(3, permuted_4)

    _store_mask64((__mmask64*) out, permuted_1);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 1, permuted_2);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 2, permuted_3);
    asm volatile ("" ::: "memory");
    _store_mask64((__mmask64*) out + 3, permuted_4);
#undef DO_ITER
}

512-bit

I didn't need such a wide permutation, but it wouldn't be that hard to do with VBMI. The only wrinkle is that you'd have to store the permutation indices as 16 bits each, but once you'd converted it to byte + bit index form, it should be the same as the previous solution—just replace the vbroadcasti64x4 load with a plain old 512-bit vmovdqu.
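Below is a sketch of the preprocessing described above, assuming the indices arrive as uint16_t values 0..511. The function names are hypothetical. With VBMI, byte_idx would presumably feed eight vpermb steps of 64 bits each (instead of four), with bit_mask supplying the matching vptestmb operands. apply_512 is only a scalar reference for checking them.

```c
#include <stdint.h>

/* Hypothetical preprocessing: split 9-bit source indices into a byte index
   (0..63, the vpermb control) and a one-bit mask (the vptestmb operand). */
static void split_512_indices(const uint16_t idx[512], uint8_t byte_idx[512],
                              uint8_t bit_mask[512])
{
    for (int i = 0; i < 512; i++) {
        byte_idx[i] = (uint8_t)((idx[i] >> 3) & 63);
        bit_mask[i] = (uint8_t)(1u << (idx[i] & 7));
    }
}

/* Hypothetical scalar reference: output bit i = input bit idx[i].
   out must not alias x. */
static void apply_512(const uint8_t x[64], const uint8_t byte_idx[512],
                      const uint8_t bit_mask[512], uint8_t out[64])
{
    for (int j = 0; j < 64; j++)
        out[j] = 0;
    for (int i = 0; i < 512; i++)
        if (x[byte_idx[i]] & bit_mask[i])
            out[i >> 3] |= (uint8_t)(1u << (i & 7));
}
```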

Polity answered 14/6, 2023 at 1:45 Comment(0)
