From bcf8c5c384ed339fd8da232aa80fa6b0a7c8729b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 18:28:40 +0300 Subject: [PATCH] k_quants: WIP super-blocks with 64 weights Q5_K scalar and AVX2 works, and with that all k_quants are done on AVX2 and scalar --- k_quants.c | 212 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 14 deletions(-) diff --git a/k_quants.c b/k_quants.c index 32eea3660..86c224fc9 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2487,8 +2487,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri #else - int8_t aux8[QK_K]; - int16_t aux16[8]; + uint8_t aux8[QK_K]; + int16_t aux16[16]; float sums [8]; memset(sums, 0, 8*sizeof(float)); @@ -2496,24 +2496,20 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l+32] = (int8_t)(q4[l] >> 4); + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); for (int j = 0; j < QK_K/32; ++j) { const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]); - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) sums[l] += d * aux16[l]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]); } } for (int l = 0; l < 8; ++l) sumf += sums[l]; @@ -2522,6 +2518,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri } #endif +#if QK_K == 256 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2772,6 +2769,193 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #endif } +#else + +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m1 = _mm256_set1_epi16(1); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d1 = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); + const float d2 = y[i].d * ggml_fp16_to_fp32(x[i].d[2]); + summs -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)(q5+ 0)); + + const uint64_t * restrict sc = (const uint64_t *)x[i].qh; + const __m128i hbits_0 = _mm_set_epi64x(sc[0] >> 1, sc[0] >> 0); + const __m128i hbits_1 = _mm_set_epi64x(sc[0] >> 3, sc[0] >> 2); + const __m128i hbits_2 = _mm_set_epi64x(sc[0] >> 5, sc[0] >> 4); + const __m128i hbits_3 = _mm_set_epi64x(sc[0] >> 7, sc[0] >> 6); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_1, hbits_0), mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_3, hbits_2), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_1, q8_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d1), _mm256_cvtepi32_ps(p16_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d2), _mm256_cvtepi32_ps(p16_1), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0); + } + + sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + for (int j = 0; j < QK_K/32; ++j) { + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]); + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + #if QK_K == 256 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {