AVX-512BW emulation of _mm512_dpbusd_epi32 AVX-512VNNI instruction

Starting with Cascade Lake, Intel CPUs support the AVX-512 VNNI instructions, which can accelerate inference of quantized neural networks on the CPU. In particular, the instruction _mm512_dpbusd_epi32 (vpdpbusd) multiplies unsigned 8-bit integers by signed 8-bit integers, adds each group of four adjacent products, and accumulates the result into 32-bit integer accumulators. Here is pseudocode for this instruction:

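#include <stdint.h>

// Scalar model of _mm512_dpbusd_epi32: lane i accumulates the dot product of
// four unsigned bytes a[i][0..3] and four signed bytes b[i][0..3].
void dpbusd_epi32_scalar(int32_t sum[16], uint8_t a[16][4], int8_t b[16][4])
{
    for(int i = 0; i < 16; ++i)
        sum[i] += 
            (int)a[i][0]*b[i][0] + (int)a[i][1]*b[i][1] +
            (int)a[i][2]*b[i][2] + (int)a[i][3]*b[i][3];
}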

Unfortunately, Intel CPUs before Cascade Lake do not have this instruction, so it has to be emulated with an earlier extension such as AVX-512BW. My question is: how can this emulation be made as efficient as possible?

Irmairme answered 16/6, 2021 at 9:4

I don't think this question has a single correct answer.

On the one hand, a fast emulation of _mm512_dpbusd_epi32 using the AVX-512BW extension could look like this:

inline __m512i _mm512_dpbusd_epi32_bw_fast(__m512i i32, __m512i u8, __m512i i8)
{
    __m512i i16 = _mm512_maddubs_epi16(u8, i8); // pair sums saturate to INT16: possible wrong result.
    __m512i _1 = _mm512_set1_epi16(1);
    return _mm512_add_epi32(i32, _mm512_madd_epi16(i16, _1));
}

This implementation uses only 3 instructions, all of them fast. However, it can give an incorrect result because _mm512_maddubs_epi16 adds each pair of u8*i8 products with signed saturation to INT16. For example, 255*127 + 255*127 = 64770 exceeds INT16_MAX (32767).
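The following is a minimal scalar sketch of one int16 lane of _mm512_maddubs_epi16 that makes the failure mode concrete. The helper name maddubs_lane is hypothetical and is not an intrinsic. It models the documented vpmaddubsw behavior: two u8*i8 products are added with signed saturation to int16.

```cpp
#include <cstdint>

// Hypothetical scalar model of one int16 lane of vpmaddubsw:
// a0*b0 + a1*b1 computed exactly, then saturated to the int16 range.
int16_t maddubs_lane(uint8_t a0, int8_t b0, uint8_t a1, int8_t b1) {
    int32_t s = int32_t(a0) * b0 + int32_t(a1) * b1;
    if (s > INT16_MAX) s = INT16_MAX;  // positive saturation
    if (s < INT16_MIN) s = INT16_MIN;  // negative saturation
    return int16_t(s);
}
```

For example, maddubs_lane(255, 127, 255, 127) returns 32767 instead of the exact 64770. The fast emulation is therefore only safe when the data keep every pair sum inside the int16 range. One such case is restricting the signed operand to [-64, 63], since 2*255*(-64) = -32640 still fits.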

On the other hand, the exact emulation is clumsy and takes 14 instructions, some of which are notably slow:

// Adds adjacent pairs of int32 lanes: result lanes 0..7 come from a, lanes 8..15 from b.
inline __m512i _mm512_hadd_epi32(__m512i a, __m512i b)
{
    static const __m512i IDX0 = _mm512_setr_epi32(
        0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, 
        0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E);
    static const __m512i IDX1 = _mm512_setr_epi32(
        0x01, 0x03, 0x05, 0x07, 0x09, 0x0B, 0x0D, 0x0F, 
        0x11, 0x13, 0x15, 0x17, 0x19, 0x1B, 0x1D, 0x1F);
    __m512i ab0 = _mm512_permutex2var_epi32(a, IDX0, b);
    __m512i ab1 = _mm512_permutex2var_epi32(a, IDX1, b);
    return _mm512_add_epi32(ab0, ab1);
}

inline __m512i _mm512_dpbusd_epi32_bw_exact(__m512i i32, __m512i u8, __m512i i8)
{
    // Bytes 0..31: zero-extend u8 and sign-extend i8 to int16; each u8*i8 product fits in int16.
    __m512i u8_i16lo = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 0));
    __m512i i8_i16lo = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 0));
    __m512i i32lo = _mm512_madd_epi16(u8_i16lo, i8_i16lo); // int32 sums of byte pairs
    // Bytes 32..63: same as above.
    __m512i u8_i16hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(u8, 1));
    __m512i i8_i16hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(i8, 1));
    __m512i i32hi = _mm512_madd_epi16(u8_i16hi, i8_i16hi);
    // Adding adjacent pair sums yields the 4-byte dot products in lane order.
    return _mm512_add_epi32(i32, _mm512_hadd_epi32(i32lo, i32hi));
}
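The lane bookkeeping in _mm512_hadd_epi32 may be easier to follow in scalar form. The following hypothetical model (hadd_epi32_model is not part of the answer) reproduces the IDX0/IDX1 selection over the 32-lane concatenation of a and b:

```cpp
#include <array>
#include <cstdint>

using Lanes = std::array<int32_t, 16>;

// Hypothetical scalar model of the permutex2var-based _mm512_hadd_epi32:
// IDX0[j] = 2j and IDX1[j] = 2j+1 index into the 32 lanes a[0..15], b[0..15].
Lanes hadd_epi32_model(const Lanes& a, const Lanes& b) {
    Lanes r{};
    for (int j = 0; j < 16; ++j) {
        int i0 = 2 * j, i1 = 2 * j + 1;
        int32_t x0 = i0 < 16 ? a[i0] : b[i0 - 16];  // lane picked by IDX0
        int32_t x1 = i1 < 16 ? a[i1] : b[i1 - 16];  // lane picked by IDX1
        r[j] = x0 + x1;
    }
    return r;
}
```

Since i32lo holds the pair sums for bytes 0..31 and i32hi holds those for bytes 32..63, lane j of the result sums bytes 4j..4j+3. This is the same grouping vpdpbusd uses.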
Nagana answered 16/6, 2021 at 9:25

Ermlg's inexact solution is probably quite practical, but here is an exact emulation that is slightly faster than their exact one. It splits the pmaddubsw into two pieces so that neither can saturate. (Analysis: in the first pmaddubsw every u8 operand is either 0 or 128. Even if every byte of u8 has its high bit set and i8 is identically -2^7, the pair sum is 2 * 128 * (-128) = -2^15, which is representable without saturation. In the second pmaddubsw every u8 operand is at most 127, so the pair sums stay within [-32512, 32258].)
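The bound argument can also be checked exhaustively. This sketch uses a hypothetical helper, pair_sum_fits_int16. It scans every single u8*i8 product after masking the u8 operand. It then checks that twice the extreme product stays within int16, which is the worst case for a pmaddubsw pair sum because both elements of a pair vary independently.

```cpp
#include <cstdint>

// Hypothetical checker: true if no pmaddubsw pair sum of mask(u) * b can
// saturate, for all u in [0, 255] and b in [-128, 127].
template <typename Mask>
bool pair_sum_fits_int16(Mask mask) {
    int32_t lo = 0, hi = 0;
    for (int u = 0; u < 256; ++u)
        for (int b = -128; b < 128; ++b) {
            int32_t p = int32_t(mask(uint8_t(u))) * b;
            if (p < lo) lo = p;
            if (p > hi) hi = p;
        }
    // A pair sum ranges over [2*lo, 2*hi].
    return 2 * lo >= INT16_MIN && 2 * hi <= INT16_MAX;
}
```

With the 0x80 mask the pair-sum extremes are -32768 and 32512, and with the 0x7F mask they are -32512 and 32258. Without any mask the check fails, because 2*255*(-128) = -65280.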

__m512i patch_mm512_dpbusd_epi32(__m512i i32, __m512i u8, __m512i i8) {
    const __m512i ones = _mm512_set1_epi16(1);
    const __m512i highest_bit = _mm512_set1_epi8(0x80);

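    // u8 == (u8 & 0x80) + (u8 & 0x7F); multiply each half separately.
    __m512i s1 = _mm512_maddubs_epi16(_mm512_and_si512(u8, highest_bit), i8);
    __m512i s2 = _mm512_maddubs_epi16(_mm512_andnot_si512(highest_bit, u8), i8);

    // Sum adjacent int16 pairs into int32 lanes.
    s1 = _mm512_madd_epi16(s1, ones);
    s2 = _mm512_madd_epi16(s2, ones);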

    return _mm512_add_epi32(_mm512_add_epi32(s1, s2), i32);
}

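After the constants are hoisted out of a loop, this uses eight fairly fast instructions. The same sequence also works with smaller vectors. The 128-bit form needs SSSE3 (pmaddubsw is an SSSE3 instruction, not SSE2), and the 256-bit form needs AVX2.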
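As a sketch of the smaller-vector case, here is the same sequence with 128-bit intrinsics, plus a scalar reference to compare against. The names dpbusd_epi32_ssse3 and dpbusd_epi32_ref are hypothetical. The `__attribute__((target("ssse3")))` annotation assumes GCC or Clang on x86-64 and lets the file build without -mssse3; MSVC does not need it.

```cpp
#include <immintrin.h>
#include <cstdint>

// Hypothetical 128-bit version of the split-pmaddubsw emulation of vpdpbusd.
// _mm_maddubs_epi16 is SSSE3; the remaining intrinsics are SSE2.
__attribute__((target("ssse3")))
__m128i dpbusd_epi32_ssse3(__m128i i32, __m128i u8, __m128i i8) {
    const __m128i ones = _mm_set1_epi16(1);
    const __m128i highest_bit = _mm_set1_epi8(char(0x80));
    // u8 == (u8 & 0x80) + (u8 & 0x7F); neither half can saturate pmaddubsw.
    __m128i s1 = _mm_maddubs_epi16(_mm_and_si128(u8, highest_bit), i8);
    __m128i s2 = _mm_maddubs_epi16(_mm_andnot_si128(highest_bit, u8), i8);
    // Widen adjacent int16 pairs to int32 sums.
    s1 = _mm_madd_epi16(s1, ones);
    s2 = _mm_madd_epi16(s2, ones);
    return _mm_add_epi32(_mm_add_epi32(s1, s2), i32);
}

// Scalar reference for 4 int32 lanes: lane i sums bytes 4i..4i+3.
void dpbusd_epi32_ref(int32_t sum[4], const uint8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 4; ++i)
        for (int k = 0; k < 4; ++k)
            sum[i] += int32_t(a[4 * i + k]) * b[4 * i + k];
}
```

Calling dpbusd_epi32_ssse3 requires a CPU with SSSE3. With all bytes of u8 set to 255 and all bytes of i8 set to 127, each lane gets 4*255*127 = 129540, a value the unsplit pmaddubsw could not produce. An AVX2 version is the same code with the corresponding _mm256_* intrinsics.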

Fenrir answered 24/5, 2023 at 23:11
Related: an i8 * i8 version of the same trick is possible, using set1_epi16(-1) for the MSB sums, or using the same constant for both and subtracting instead of adding. I updated my answer on "How to implement an efficient _mm256_madd_epi8?" with this cool split trick. - Graiae
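The i8 * i8 variant described in the comment can be sketched the same way. When its sign bit is set, a signed byte a equals (a & 0x7F) - 128. The MSB partial sums are therefore weighted by -1 through _mm_madd_epi16 with set1_epi16(-1). The function names below are hypothetical, and this is not the code from the linked answer.

```cpp
#include <immintrin.h>
#include <cstdint>

// Hypothetical SSSE3 sketch: int32 lane i += sum of a[4i+k] * b[4i+k], with a and b signed.
__attribute__((target("ssse3")))
__m128i madd_i8i8_epi32_ssse3(__m128i i32, __m128i a, __m128i b) {
    const __m128i ones = _mm_set1_epi16(1);
    const __m128i minus_ones = _mm_set1_epi16(-1);
    const __m128i highest_bit = _mm_set1_epi8(char(0x80));
    // pmaddubsw reads its first operand as unsigned: 0 or 128 here ...
    __m128i s_hi = _mm_maddubs_epi16(_mm_and_si128(a, highest_bit), b);
    // ... and 0..127 here, so neither product pair can saturate.
    __m128i s_lo = _mm_maddubs_epi16(_mm_andnot_si128(highest_bit, a), b);
    s_hi = _mm_madd_epi16(s_hi, minus_ones);  // subtract the 128*b contributions
    s_lo = _mm_madd_epi16(s_lo, ones);
    return _mm_add_epi32(_mm_add_epi32(s_lo, s_hi), i32);
}

// Scalar reference for 4 int32 lanes.
void madd_i8i8_epi32_ref(int32_t sum[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 4; ++i)
        for (int k = 0; k < 4; ++k)
            sum[i] += int32_t(a[4 * i + k]) * b[4 * i + k];
}
```

With every byte of both inputs equal to -128, each lane receives 4 * 16384 = 65536. In that case the MSB pmaddubsw yields exactly -32768 per int16 lane, and the madd by -1 turns it back into a positive int32 sum.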
