k_quants: WIP super-blocks with 64 weights
Q6_K scalar and AVX2 works
This commit is contained in:
parent
d2f12ac354
commit
9fe2a2b1db
1 changed files with 246 additions and 1 deletions
247
k_quants.c
247
k_quants.c
|
@ -2184,7 +2184,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
assert(n % QK_K == 0);
|
assert(n % QK_K == 0);
|
||||||
|
|
||||||
|
@ -2453,3 +2453,248 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
|
assert(n % QK_K == 0);
|
||||||
|
|
||||||
|
const block_q6_K * restrict x = vx;
|
||||||
|
const block_q8_K * restrict y = vy;
|
||||||
|
|
||||||
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
||||||
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
//const int8x16_t m32s = vdupq_n_s8(32);
|
||||||
|
|
||||||
|
const uint8x16_t mone = vdupq_n_u8(3);
|
||||||
|
|
||||||
|
int8x16x4_t q6bytes;
|
||||||
|
uint8x16x4_t q6h;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d_all = ggml_fp16_to_fp32(x[i].d);
|
||||||
|
|
||||||
|
const uint8_t * restrict q6 = x[i].ql;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const int8_t * restrict scale = x[i].scales;
|
||||||
|
|
||||||
|
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
|
||||||
|
const int8x16_t scales = vld1q_s8(scale);
|
||||||
|
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
|
||||||
|
|
||||||
|
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
||||||
|
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
|
||||||
|
vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
|
||||||
|
int32_t isum_mins = vaddvq_s32(prod);
|
||||||
|
|
||||||
|
int32_t isum = 0;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
|
uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
|
||||||
|
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
|
||||||
|
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
|
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
||||||
|
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
||||||
|
uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
|
||||||
|
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
shifted = vshrq_n_u8(qhbits.val[1], 2);
|
||||||
|
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
|
||||||
|
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
|
||||||
|
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
|
||||||
|
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
|
||||||
|
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
|
||||||
|
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
|
||||||
|
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
|
||||||
|
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
|
||||||
|
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
|
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
||||||
|
scale += 4;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
|
||||||
|
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
|
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
||||||
|
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
shifted = vshrq_n_u8(qhbits.val[1], 4);
|
||||||
|
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
shifted = vshrq_n_u8(qhbits.val[0], 6);
|
||||||
|
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
shifted = vshrq_n_u8(qhbits.val[1], 6);
|
||||||
|
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
||||||
|
|
||||||
|
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
|
||||||
|
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
|
||||||
|
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
|
||||||
|
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
|
||||||
|
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
|
||||||
|
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
|
||||||
|
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
|
||||||
|
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
|
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
||||||
|
scale += 4;
|
||||||
|
|
||||||
|
//for (int l = 0; l < 4; ++l) {
|
||||||
|
// const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
|
||||||
|
// isum += vaddvq_s32(p) * *scale++;
|
||||||
|
//}
|
||||||
|
#else
|
||||||
|
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
|
||||||
|
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
//sum += isum * d_all * y[i].d;
|
||||||
|
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||||
|
|
||||||
|
}
|
||||||
|
*s = sum;
|
||||||
|
|
||||||
|
#elif defined z__AVX2__
|
||||||
|
|
||||||
|
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||||
|
const __m256i m2 = _mm256_set1_epi8(3);
|
||||||
|
const __m256i m32s = _mm256_set1_epi8(32);
|
||||||
|
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||||
|
|
||||||
|
const uint8_t * restrict q4 = x[i].ql;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
|
||||||
|
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
|
||||||
|
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
|
||||||
|
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
|
||||||
|
|
||||||
|
__m256i sumi = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
|
||||||
|
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
|
||||||
|
|
||||||
|
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
||||||
|
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
||||||
|
|
||||||
|
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
||||||
|
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
||||||
|
|
||||||
|
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
||||||
|
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
|
||||||
|
|
||||||
|
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||||
|
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||||
|
|
||||||
|
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
||||||
|
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
||||||
|
|
||||||
|
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
||||||
|
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
||||||
|
|
||||||
|
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
||||||
|
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
||||||
|
|
||||||
|
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
||||||
|
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
||||||
|
|
||||||
|
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
||||||
|
|
||||||
|
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_8(acc);
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
int8_t aux8[QK_K];
|
||||||
|
int16_t aux16[8];
|
||||||
|
float sums [8];
|
||||||
|
int32_t aux32[8];
|
||||||
|
memset(sums, 0, 8*sizeof(float));
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint8_t * restrict q4 = x[i].ql;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
memset(aux32, 0, 8*sizeof(int32_t));
|
||||||
|
int8_t * restrict a = aux8;
|
||||||
|
for (int l = 0; l < 16; ++l) {
|
||||||
|
a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
||||||
|
a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
||||||
|
a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
||||||
|
a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
||||||
|
}
|
||||||
|
int is = 0;
|
||||||
|
for (int j = 0; j < QK_K/16; ++j) {
|
||||||
|
int scale = x[i].scales[is++];
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||||
|
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||||
|
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
}
|
||||||
|
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
||||||
|
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
||||||
|
}
|
||||||
|
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||||
|
*s = sumf;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue