diff --git a/.gitignore b/.gitignore index 631f2360b..e52d479ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,15 @@ *.o *.a +.DS_Store +.build/ .cache/ +.direnv/ +.envrc +.swiftpm +.venv .vs/ .vscode/ -.DS_Store -.build/ build/ build-em/ build-debug/ @@ -30,12 +34,9 @@ models/* arm_neon.h compile_commands.json -.envrc -.direnv/ - -.venv __pycache__ -.swiftpm zig-out/ zig-cache/ + +ppl-*.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 8eadea4fd..d7aa051da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -174,7 +174,6 @@ if (LLAMA_ALL_WARNINGS) -Wshadow -Wstrict-prototypes -Wpointer-arith - -Wno-unused-function ) set(cxx_flags -Wall diff --git a/Makefile b/Makefile index deb0d0009..d9a2d836b 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = # warnings -CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function +CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar # OS specific diff --git a/README.md b/README.md index c6f24d032..8e1945cff 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ +**Warnings** + +- `Q4_2` and `Q4_3` are still in development. Do not expect any kind of backward compatibility until they are finalize + **Hot topics:** - [Added LoRA support](https://github.com/ggerganov/llama.cpp/pull/820) diff --git a/ggml.c b/ggml.c index 13c1548fe..3b38eaad3 100644 --- a/ggml.c +++ b/ggml.c @@ -550,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) { (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15); } +inline static int16_t vaddvq_s8(int8x16_t v) { + return + (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) + + (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) + + (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) + + (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) + + (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) + + (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) + + (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) + + (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15); +} + inline static int32_t vaddvq_s16(int16x8_t v) { return (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + @@ -1535,9 +1547,8 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in } } -static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -//static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { @@ -1552,8 +1563,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .dequantize_row_q = dequantize_row_q4_1, .quantize_row_q = quantize_row_q4_1, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, - .quantize_row_q_dot = quantize_row_q4_1, - .vec_dot_q = ggml_vec_dot_q4_1, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q4_1_q8_0, }, [GGML_TYPE_Q4_2] = { .dequantize_row_q = dequantize_row_q4_2, @@ -1562,7 +1573,13 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0, }, - // TODO: GGML_TYPE_Q8_0 + [GGML_TYPE_Q8_0] = { + .dequantize_row_q = NULL, // TODO + .quantize_row_q = quantize_row_q8_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = NULL, // TODO + }, }; // For internal test use @@ -2128,191 +2145,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float *s = sumf; } -#if __AVX512F__ && QK4_0 == 32 -static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) { - // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory: - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // | :. =_ () [] <> () Zz Yy| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa | - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // - // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers). - // We have exactly 64 nibbles, so we want to place each nibble into a separate byte. - // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function. - // Bytes 40..63 are masked when loading the data, so they are zeroed out. -#ifdef __AVX512VBMI__ - const __m512i byte_perm = _mm512_set_epi8( - 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32, - 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24, - 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12, - 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4 - ); - const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks ); - // After applying VPERMB, `permuted` looks like this: - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ -#else - const __m512i word_perm = _mm512_set_epi16( - 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12, - 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2 - ); - const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks ); - // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only, - // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and - // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB. -#endif - - // Shift every odd-numbered 16-bit group to the right by 4 bits. - const __mmask32 shift_mask = 0xaaaaaaaa; - const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 ); - // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes): - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32 - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - - // Now we just need to zero out the higher nibble in each byte, and we're done. - const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf ); - return _mm512_and_si512( low_nibble_mask, shifted ); - // The final result looks like this: - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ - // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a| - // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ -} - -static inline __m512 dot_q4_0_twoblocks_avx512( - __m512 acc, - const block_q4_0 * restrict x, - const block_q4_0 * restrict y, - int i -) { - // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes - // can potentially be unaddressable, so we make sure to mask them out before the load, even though - // we don't use them at all. This might hurt the performance slightly, since the compiler is forced - // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`. - const __mmask8 load_mask = 0x1f; - const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] ); - const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] ); - - // We want to multiply the scales, so we interpret both registers as 16 32-bit floats: - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // blocks_0_float - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // blocks_1_float - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 ); - const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 ); - // We absolutely shouldn't touch the floats marked with `xx`: they contain some - // random data, which might very well underflow. At least on Intel, this leads - // to a huge penalty that can't be ignored (easily 100x or more) unless you - // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags. - // (and ggml can't assume that you do)... - const __mmask16 scale_mul_mask = 0x21; -#ifdef __clang__ - // ...however, clang decides to optimize the multiplication mask away: - // https://godbolt.org/z/P8PqdsfvW - // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask. - __m512i scales; - __asm__( - "vmulps %1, %2, %0%{%3%}" - : "=v" ( scales ) - : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask ) - ); -#else - const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float ); -#endif - const __m512i scale_perm = _mm512_set_epi32( - 5, 5, 5, 5, 5, 5, 5, 5, - 0, 0, 0, 0, 0, 0, 0, 0 - ); - const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales ); - // After VMULPS and VPERMPS, `permuted_scales` looks like this: - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 | - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C| - // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+ - - const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 ); - const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 ); - - // Now we want to compute dot products of 4-element byte vectors and store them in - // 32-bit integers. That is (only one 4-element vector is shown for clarity): - // +----+----+----+----+ - // ... | 03 | 02 | 01 | 00 | - // +----+----+----+----+ - // bytes_0 - // +----+----+----+----+ - // ... | D | C | B | A | - // +----+----+----+----+ - // bytes_1 - // +----+----+----+----+ - // ... | H | G | F | E | - // +----+----+----+----+ - // final_res_int - // +----+----+----+----+ - // ... | A*E+B*F+C*G+D*H | - // +----+----+----+----+ - const __m512i plus_8 = _mm512_set1_epi8( 8 ); - const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 ); - -#ifdef __AVX512VNNI__ - // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch: - // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8 - // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`, - // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator, - // which means we only need 2 instructions. - const __m512i dot_init = _mm512_set1_epi32( 4 * 64 ); - const __m512i minus_8 = _mm512_set1_epi8( -8 ); - const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 ); - const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 ); -#else - // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones. - // It has the same catch as VPDPBUSDS: the left operand should be unsigned. - // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me - // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119 - const __m512i one = _mm512_set1_epi16( 1 ); - const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 ); - const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 ); - const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 ); - const __m512i final_res_int = _mm512_madd_epi16( diff, one ); -#endif - - // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate. - const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int ); - return _mm512_fmadd_ps( permuted_scales, final_res_float, acc ); -} -#endif - inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0; @@ -2349,535 +2181,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t *s = sumf; } -static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK4_0; - - assert(n % QK4_0 == 0); - assert(nb % 2 == 0); - - const block_q4_0 * restrict x = vx; - const block_q4_0 * restrict y = vy; - - float sumf = 0.0; - -#if defined(__ARM_NEON) - float sum0 = 0.0f; - float sum1 = 0.0f; - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict y0 = &y[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q4_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v1_0 = vld1q_u8(y0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - const uint8x16_t v1_1 = vld1q_u8(y1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); - const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); - - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); - const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); - - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); - int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); - - p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); - p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); - - sum0 += x0->d*y0->d*vaddvq_s32(p_0); - sum1 += x1->d*y1->d*vaddvq_s32(p_1); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); - - const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); - const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); - - const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); - const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); - - const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); - const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - - sum0 += x0->d*y0->d*vaddvq_s16(p_0); - sum1 += x1->d*y1->d*vaddvq_s16(p_1); -#endif - } - - sumf = sum0 + sum1; -#elif defined(__AVX512F__) - // Initialize accumulator with zeros - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - - const int superblock_size = 16; - - const int superblock_count = nb / superblock_size; - - for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) { - int i = superblock_ix * superblock_size; - - acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 ); - acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 ); - acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 ); - acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 ); - acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 ); - acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 ); - acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 ); - acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 ); - } - - // Remainders - for (int i = superblock_count * superblock_size; i < nb; i += 2) { - acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i ); - } - - // Horizontal sum of all lanes of the accumulator - sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 ); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - /* Prepare the constants we will need during execution */ - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - const __m256i offset_8 = _mm256_set1_epi16( 8 ); - -#define UNROLL_COUNT 8 - // make sure we only unroll multiples of the block count - assert(nb % UNROLL_COUNT == 0); - - // Main loop - for (int i = 0; i < nb; i+=UNROLL_COUNT) { - // This loop will be unrolled by the compiler - for (int u=0;u we now have a vector of 8 int_32t */ - __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q ); - - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps( xy_q ); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( scale, q, acc ); - } - } - - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); - - __m128i i32[2]; - for (int j = 0; j < 2; ++j) { - // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes - __m128i bx = bytesFromNibbles( x[i].qs + 8*j ); - __m128i by = bytesFromNibbles( y[i].qs + 8*j ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m128i off = _mm_set1_epi8( 8 ); - bx = _mm_sub_epi8( bx, off ); - by = _mm_sub_epi8( by, off ); - - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(bx, bx); - - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(by, bx); - - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - - const __m128i ones = _mm_set1_epi16(1); - i32[j] = _mm_madd_epi16(ones, dot); - } - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] )); - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); - } - - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); -#elif defined(__wasm_simd128__) - // wasm simd - float sum0 = 0.0f; - float sum1 = 0.0f; - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict y0 = &y[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q4_0 * restrict y1 = &y[i + 1]; - - const v128_t m4b = wasm_u8x16_splat(0xf); - const v128_t s8b = wasm_i8x16_splat(0x8); - - const v128_t v0_0 = wasm_v128_load(x0->qs); - const v128_t v0_1 = wasm_v128_load(y0->qs); - const v128_t v1_0 = wasm_v128_load(x1->qs); - const v128_t v1_1 = wasm_v128_load(y1->qs); - - // 4-bit -> 8-bit - const v128_t v0_0l = wasm_v128_and(v0_0, m4b); - const v128_t v1_0l = wasm_v128_and(v1_0, m4b); - - const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); - const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4); - - const v128_t v0_1l = wasm_v128_and(v0_1, m4b); - const v128_t v1_1l = wasm_v128_and(v1_1, m4b); - - const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); - const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4); - - // sub 8 - const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); - const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b); - - const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); - const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b); - - const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); - const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b); - - const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); - const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b); - - // dot product into int16x8_t - const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls)); - const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls)); - - const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs)); - const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs)); - - const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls)); - const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls)); - - const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs)); - const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs)); - - const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h); - const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h); - - const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h); - const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h); - - const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); - const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); - - sum0 += x0->d * y0->d * ( - wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + - wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + - wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + - wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); - sum1 += x1->d * y1->d * ( - wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + - wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + - wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + - wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7)); - } - - sumf = sum0 + sum1; -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d0 = x[i].d; - const float d1 = y[i].d; - - const uint8_t * restrict p0 = x[i].qs; - const uint8_t * restrict p1 = y[i].qs; - - int sumi = 0; - for (int j = 0; j < QK4_0/2; j++) { - const uint8_t v0 = p0[j]; - const uint8_t v1 = p1[j]; - - const int i0 = (v0 & 0xf) - 8; - const int i1 = (v0 >> 4) - 8; - - const int i2 = (v1 & 0xf) - 8; - const int i3 = (v1 >> 4) - 8; - - sumi += i0*i2 + i1*i3; - } - sumf += d0 * d1 * sumi; - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK4_1; - - const block_q4_1 * restrict x = vx; - const block_q4_1 * restrict y = vy; - - float sumf = 0.0; - -#if defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - // Accumulator for constant offsets - float acc_offset = 0.0f; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float * d0 = &x[i].d; - const float * d1 = &y[i].d; - - const float * m0 = &x[i].m; - const float * m1 = &y[i].m; - - const __m256 d0v = _mm256_broadcast_ss( d0 ); - const __m256 d1v = _mm256_broadcast_ss( d1 ); - const __m256 m0v = _mm256_broadcast_ss( m0 ); - const __m256 m1v = _mm256_broadcast_ss( m1 ); - - // Compute combined scale for the block - const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); - - // Compute cross scales for the block - const __m256 scale_0 = _mm256_mul_ps( d0v, m1v ); - const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); - const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - __m256i bx = bytesFromNibbles( x[i].qs ); - __m256i by = bytesFromNibbles( y[i].qs ); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. - - // Sign-extend first 16 signed bytes into int16_t - __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); - __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - // Compute products of int16_t integers, add pairwise - __m256i i32 = _mm256_madd_epi16( x16, y16 ); - - // Sign-extend last 16 signed bytes into int16_t vectors - __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); - __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - // Accumulate products of int16_t integers - i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); - - // compute sums of unsigned bytes in bx, by in blocks of 8. - // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, - // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. - // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] - __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); - __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); - __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); - __m256 sums = _mm256_cvtepi32_ps( sumsi ); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps( i32 ); - // Apply the scale, and accumulate - // acc += d0*d1*x*y + d0*m1*x + d1*m0*y - acc = _mm256_fmadd_ps( scale_01, p, acc ); - acc = _mm256_fmadd_ps( cross_scales, sums, acc ); - // acc_offset += m0*m1 (for each entry in the block) - acc_offset += (*m0)*(*m1); - } - - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1; -#elif defined(__ARM_NEON) - float sum00 = 0.0f; - float sum01 = 0.0f; - float sum10 = 0.0f; - float sum11 = 0.0f; - - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict y0 = &y[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q4_1 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v1_0 = vld1q_u8(y0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - const uint8x16_t v1_1 = vld1q_u8(y1->qs); - - // 4-bit -> 8-bit - const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); - const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); - const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - - const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); - const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); - const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); - const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); - - sum00 += x0->m*y0->m; - sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h)); - sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h)); - - sum00 += x1->m*y1->m; - sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h)); - sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h)); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l); - uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l); - - p_0 = vdotq_u32(p_0, v0_0h, v1_0h); - p_1 = vdotq_u32(p_1, v0_1h, v1_1h); - - sum11 += x0->d*y0->d*vaddvq_u32(p_0); - sum11 += x1->d*y1->d*vaddvq_u32(p_1); -#else - const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); - const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); - const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); - const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); - - const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l)); - const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l)); - const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); - const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); - - const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h); - const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h); - - const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h); - const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h); - - const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0); - const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1); - - sum11 += x0->d*y0->d*vaddvq_u16(p_0); - sum11 += x1->d*y1->d*vaddvq_u16(p_1); -#endif - } - - sumf = QK4_1*sum00 + sum01 + sum10 + sum11; -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d0 = x[i].d; - const float d1 = y[i].d; - - const float m0 = x[i].m; - const float m1 = y[i].m; - - const uint8_t * restrict p0 = x[i].qs; - const uint8_t * restrict p1 = y[i].qs; - - for (int j = 0; j < QK4_1/2; j++) { - const uint8_t v0 = p0[j]; - const uint8_t v1 = p1[j]; - - const float f0 = d0*(v0 & 0xf) + m0; - const float f1 = d0*(v0 >> 4) + m0; - - const float f2 = d1*(v1 & 0xf) + m1; - const float f3 = d1*(v1 >> 4) + m1; - - sumf += f0*f2 + f1*f3; - } - } -#endif - - *s = sumf; -} - static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_0; @@ -3074,6 +2377,175 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * *s = sumf; } +static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + + const block_q4_1 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + float sumf = 0.0; + + // TODO: add AVX / WASM SIMD / etc +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // interleave + const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h); + const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h); + const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); + const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); + + const int16x8_t s0i = vaddq_s16( + vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))), + vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs)))); + + const int16x8_t s1i = vaddq_s16( + vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))), + vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs)))); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); +#endif + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + const float * d0 = &x[i].d; + const float * d1 = &y[i].d; + const float * m0 = &x[i].m; + + const __m256 d0v = _mm256_broadcast_ss( d0 ); + const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 m0v = _mm256_broadcast_ss( m0 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + const __m256 d1m0 = _mm256_mul_ps( d1v, m0v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytesFromNibbles( x[i].qs ); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8( bx, bx ); + + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8( by, bx ); + + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16( ax, sy ); + const __m256i ones = _mm256_set1_epi16( 1 ); + const __m256i xy_q = _mm256_madd_epi16( ones, dot ); + + // Convert to vector of 8 int32_t to 8 floats + const __m256 xy = _mm256_cvtepi32_ps( xy_q ); + + // Accumulate d0*d1*x*y + acc = _mm256_fmadd_ps( d0d1, xy, acc ); + + // Compute sum of y values + const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones ); + const __m256 ysum = _mm256_cvtepi32_ps( ysumi ); + + // Accumulate d1*m0*y + acc = _mm256_fmadd_ps( d1m0, ysum, acc ); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = x[i].d; + const float m0 = x[i].m; + const float d1 = y[i].d; + + const uint8_t * restrict p0 = x[i].qs; + const int8_t * restrict p1 = y[i].qs; + + // TODO: this is very slow .. + for (int j = 0; j < QK8_0/2; j++) { + const uint8_t v0 = p0[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*p1[2*j + 0]; + const float f3 = d1*p1[2*j + 1]; + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int nb = n / QK8_0; @@ -11064,7 +10536,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; - } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { + } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1;