diff --git a/ggml.c b/ggml.c index dd31fc9b2..772879474 100644 --- a/ggml.c +++ b/ggml.c @@ -2134,191 +2134,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float *s = sumf; } -#if __AVX512F__ && QK4_0 == 32 -static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) { - // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory: - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // | :. =_ () [] <> () Zz Yy| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa | - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // - // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers). - // We have exactly 64 nibbles, so we want to place each nibble into a separate byte. - // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function. - // Bytes 40..63 are masked when loading the data, so they are zeroed out. -#ifdef __AVX512VBMI__ - const __m512i byte_perm = _mm512_set_epi8( - 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32, - 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24, - 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12, - 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4 - ); - const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks ); - // After applying VPERMB, `permuted` looks like this: - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ -#else - const __m512i word_perm = _mm512_set_epi16( - 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12, - 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2 - ); - const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks ); - // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only, - // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and - // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB. -#endif - - // Shift every odd-numbered 16-bit group to the right by 4 bits. - const __mmask32 shift_mask = 0xaaaaaaaa; - const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 ); - // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes): - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - - // Now we just need to zero out the higher nibble in each byte, and we're done. - const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf ); - return _mm512_and_si512( low_nibble_mask, shifted ); - // The final result looks like this: - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ -} - -static inline __m512 dot_q4_0_twoblocks_avx512( - __m512 acc, - const block_q4_0 * restrict x, - const block_q4_0 * restrict y, - int i -) { - // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes - // can potentially be unaddressable, so we make sure to mask them out before the load, even though - // we don't use them at all. This might hurt the performance slightly, since the compiler is forced - // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`. - const __mmask8 load_mask = 0x1f; - const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] ); - const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] ); - - // We want to multiply the scales, so we interpret both registers as 16 32-bit floats: - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // blocks_0_float - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // blocks_1_float - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 ); - const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 ); - // We absolutely shouldn't touch the floats marked with `xx`: they contain some - // random data, which might very well underflow. At least on Intel, this leads - // to a huge penalty that can't be ignored (easily 100x or more) unless you - // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags. - // (and ggml can't assume that you do)... - const __mmask16 scale_mul_mask = 0x21; -#ifdef __clang__ - // ...however, clang decides to optimize the multiplication mask away: - // https://godbolt.org/z/P8PqdsfvW - // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask. - __m512i scales; - __asm__( - "vmulps %1, %2, %0%{%3%}" - : "=v" ( scales ) - : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask ) - ); -#else - const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float ); -#endif - const __m512i scale_perm = _mm512_set_epi32( - 5, 5, 5, 5, 5, 5, 5, 5, - 0, 0, 0, 0, 0, 0, 0, 0 - ); - const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales ); - // After VMULPS and VPERMPS, `permuted_scales` looks like this: - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C| - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - - const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 ); - const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 ); - - // Now we want to compute dot products of 4-element byte vectors and store them in - // 32-bit integers. That is (only one 4-element vector is shown for clarity): - // +----+----+----+----+ - // ... | 03 | 02 | 01 | 00 | - // +----+----+----+----+ - // bytes_0 - // +----+----+----+----+ - // ... | D | C | B | A | - // +----+----+----+----+ - // bytes_1 - // +----+----+----+----+ - // ... | H | G | F | E | - // +----+----+----+----+ - // final_res_int - // +----+----+----+----+ - // ... | A*E+B*F+C*G+D*H | - // +----+----+----+----+ - const __m512i plus_8 = _mm512_set1_epi8( 8 ); - const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 ); - -#ifdef __AVX512VNNI__ - // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch: - // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8 - // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`, - // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator, - // which means we only need 2 instructions. - const __m512i dot_init = _mm512_set1_epi32( 4 * 64 ); - const __m512i minus_8 = _mm512_set1_epi8( -8 ); - const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 ); - const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 ); -#else - // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones. - // It has the same catch as VPDPBUSDS: the left operand should be unsigned. - // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me - // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119 - const __m512i one = _mm512_set1_epi16( 1 ); - const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 ); - const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 ); - const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 ); - const __m512i final_res_int = _mm512_madd_epi16( diff, one ); -#endif - - // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate. - const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int ); - return _mm512_fmadd_ps( permuted_scales, final_res_float, acc ); -} -#endif - inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0;