k_quants: WIP super-blocks with 64 weights

Q5_K scalar and AVX2 work, and with that all
k_quants are done on AVX2 and scalar.
author Iwan Kawrakow 2023-06-21 18:28:40 +03:00
parent 2b2ab31a89
commit bcf8c5c384
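For context: the diff below adds a QK_K == 64 path for ggml_vec_dot_q5_K_q8_K, but the 64-weight block_q5_K definition itself is not part of this diff. The following sketch reconstructs the layout that the new scalar and AVX2 paths imply through their indexing of x[i].d, x[i].qs and x[i].qh; the field order, the struct name, and the helper function are assumptions for illustration, not code from this commit.

    // Assumed layout of a 64-weight Q5_K super-block (inferred from the diff
    // below, not shown in it): two 32-weight sub-blocks, each with its own
    // fp16 scale and fp16 min, stored as interleaved pairs in d[].
    typedef struct {
        ggml_fp16_t d[4];    // d[2*j] = scale, d[2*j+1] = min of sub-block j (j = 0, 1)
        uint8_t     qs[32];  // low 4 bits: weight l sits in nibble l/32 of qs[l % 32]
        uint8_t     qh[8];   // 5th bit: weight l sits in bit l/8 of qh[l % 8]
    } block_q5_K; // hypothetical QK_K == 64 variant

    // Per-weight dequantization mirroring the new scalar path (hypothetical helper):
    static inline float q5_K_dequant_one(const block_q5_K * restrict b, int l) {
        const int j = l / 32;                              // sub-block index
        int q = (b->qs[l % 32] >> 4*j) & 0xF;              // 4 low bits
        q    += ((b->qh[l % 8] >> (l / 8)) & 1) ? 16 : 0;  // packed 5th bit
        return ggml_fp16_to_fp32(b->d[2*j]) * q            // sub-block scale ...
             - ggml_fp16_to_fp32(b->d[2*j + 1]);           // ... minus sub-block min
    }

Under these assumptions a super-block is 8 + 32 + 8 = 48 bytes for 64 weights, i.e. 6.0 bits per weight.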


@@ -2487,8 +2487,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #else
-    int8_t aux8[QK_K];
-    int16_t aux16[8];
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
     float sums [8];
     memset(sums, 0, 8*sizeof(float));
@@ -2496,24 +2496,20 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     for (int i = 0; i < nb; ++i) {
         const uint8_t * restrict q4 = x[i].qs;
         const  int8_t * restrict q8 = y[i].qs;
-        int8_t * restrict a = aux8;
-        for (int l = 0; l < 32; ++l) a[l+ 0] = (int8_t)(q4[l] & 0xF);
-        for (int l = 0; l < 32; ++l) a[l+32] = (int8_t)(q4[l] >> 4);
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+        for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
         sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
         for (int j = 0; j < QK_K/32; ++j) {
             const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) sums[l] += d * aux16[l];
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -2522,6 +2518,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 }
 #endif
+#if QK_K == 256
 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
@@ -2772,6 +2769,193 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
+#else
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+    const int nb = n / QK_K;
+#ifdef __ARM_NEON
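+    // NEON path (WIP: the commit message says only scalar and AVX2 are done for Q5_K so far).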
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const int32x4_t mzero = vdupq_n_s32(0);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+    int8x16x4_t q5bytes;
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        int32_t sumi_mins = vaddvq_s32(prod);
+        const uint8_t * scales = (const uint8_t *)utmp;
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        uint8x16x4_t q5h;
+        int32_t sumi = 0;
+        for (int j = 0; j < QK_K/64; ++j) {
+            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
+            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+#else
+            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
+#endif
+        }
+        sumf += d * sumi - dmin * sumi_mins;
+    }
+    *s = sumf;
+#elif defined __AVX2__
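+    // AVX2 path: one 64-weight super-block fits in two 256-bit vectors
+    // (low nibbles in q5_0, high nibbles in q5_1), so each block is handled
+    // without an inner sub-block loop; the per-sub-block mins are folded in
+    // through the precomputed q8 bsums (summs).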
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m1 = _mm256_set1_epi16(1);
+    const __m256i mone = _mm256_set1_epi8(1);
+    __m256 acc = _mm256_setzero_ps();
+    float summs = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q5 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        const float d1 = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+        const float d2 = y[i].d * ggml_fp16_to_fp32(x[i].d[2]);
+        summs -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
+                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)(q5+ 0));
+        const uint64_t * restrict sc = (const uint64_t *)x[i].qh;
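+        // Spread the 64 high bits: after the AND with mone and the shift by 4,
+        // byte k of q5h_0 holds the 5th bit of quant k (k = 0..31) at bit
+        // position 4, and byte k of q5h_1 the 5th bit of quant 32+k.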
+        const __m128i hbits_0 = _mm_set_epi64x(sc[0] >> 1, sc[0] >> 0);
+        const __m128i hbits_1 = _mm_set_epi64x(sc[0] >> 3, sc[0] >> 2);
+        const __m128i hbits_2 = _mm_set_epi64x(sc[0] >> 5, sc[0] >> 4);
+        const __m128i hbits_3 = _mm_set_epi64x(sc[0] >> 7, sc[0] >> 6);
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_1, hbits_0), mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_3, hbits_2), mone), 4);
+        const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+        const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+        const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
+        const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+        const __m256i p16_0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_1, q8_1));
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d1), _mm256_cvtepi32_ps(p16_0), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d2), _mm256_cvtepi32_ps(p16_1), acc);
+    }
+    *s = hsum_float_8(acc) + summs;
+#else
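+    // Scalar reference path: unpack all 64 quants to bytes in aux8, add the
+    // 5th bit from qh, then accumulate 32 products per sub-block scale.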
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) {
+            a[l+ 0] = q4[l] & 0xF;
+            a[l+32] = q4[l] >> 4;
+        }
+        for (int is = 0; is < 8; ++is) {
+            uint8_t m = 1 << is;
+            for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0);
+        }
+        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
+                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        for (int j = 0; j < QK_K/32; ++j) {
+            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]);
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
 #if QK_K == 256
 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {