From c3b2029698b4d0f629a890a67a9b0e93e55e2398 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 27 Jan 2024 15:48:20 +0200 Subject: [PATCH] iq3_xxs: scalar and AVX2 dot products --- ggml-quants.c | 58 ++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index fe0bf73b5..ce5343949 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -8622,7 +8622,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); - const block_iq2_xxs * restrict x = vx; + const block_iq3_xxs * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; @@ -8672,32 +8672,36 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; + uint32_t aux32[2]; __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); sumi1 = _mm256_add_epi32(sumi1, p1); @@ -8708,37 +8712,39 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res } - *s = 0.125f * hsum_float_8(accumf); + *s = 0.25f * hsum_float_8(accumf); #else - uint32_t aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; + uint32_t aux32; float sumf = 0.f; for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; int32_t bsum = 0; for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(aux32, q2, 2*sizeof(uint32_t)); - q2 += 4; - const uint32_t ls = 2*(aux32[1] >> 28) + 1; + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; int32_t sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); - const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); } q8 += 8; } + q3 += 8; bsum += sumi * ls; } sumf += d * bsum; } - *s = 0.125f * sumf; + *s = 0.25f * sumf; #endif }