k_quants: WIP super-blocks with 64 weights
Q4_K scalar and AVX2 works
This commit is contained in:
parent
9fe2a2b1db
commit
1f6195c2f2
1 changed files with 171 additions and 1 deletions
172
k_quants.c
172
k_quants.c
|
@ -1702,6 +1702,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
assert(n % QK_K == 0);
|
assert(n % QK_K == 0);
|
||||||
|
|
||||||
|
@ -1932,6 +1933,175 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
|
assert(n % QK_K == 0);
|
||||||
|
|
||||||
|
const block_q4_K * restrict x = vx;
|
||||||
|
const block_q8_K * restrict y = vy;
|
||||||
|
|
||||||
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int8x16x2_t q4bytes;
|
||||||
|
int8x16x2_t q8bytes;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||||
|
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
||||||
|
|
||||||
|
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
||||||
|
|
||||||
|
memcpy(utmp, x[i].scales, 12);
|
||||||
|
|
||||||
|
const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
|
||||||
|
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||||
|
utmp[0] &= kmask1;
|
||||||
|
|
||||||
|
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
||||||
|
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
||||||
|
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
||||||
|
sumf -= dmin * vaddvq_s32(prod);
|
||||||
|
|
||||||
|
const uint8_t * scales = (const uint8_t *)utmp;
|
||||||
|
|
||||||
|
const uint8_t * restrict q4 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
//int32x4_t isum = mzero;
|
||||||
|
|
||||||
|
int32_t sumi1 = 0;
|
||||||
|
int32_t sumi2 = 0;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
|
||||||
|
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
|
||||||
|
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
|
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
|
||||||
|
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
|
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
||||||
|
|
||||||
|
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
|
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
|
|
||||||
|
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
||||||
|
#else
|
||||||
|
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
|
||||||
|
|
||||||
|
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
sumf += d * (sumi1 + sumi2);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = sumf;
|
||||||
|
|
||||||
|
#elif defined z__AVX2__
|
||||||
|
|
||||||
|
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||||
|
const __m256i m1 = _mm256_set1_epi16(1);
|
||||||
|
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
float summs = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
|
||||||
|
ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
|
||||||
|
|
||||||
|
const uint8_t * restrict q4 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
|
||||||
|
const __m256i q4l = _mm256_and_si256(q4bits, m4);
|
||||||
|
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
||||||
|
|
||||||
|
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
||||||
|
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
|
||||||
|
|
||||||
|
const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
|
||||||
|
const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
|
||||||
|
|
||||||
|
const __m256i p32l = _mm256_madd_epi16(m1, p16l);
|
||||||
|
const __m256i p32h = _mm256_madd_epi16(m1, p16h);
|
||||||
|
|
||||||
|
acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc);
|
||||||
|
acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_8(acc) - summs;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
int8_t aux8[QK_K];
|
||||||
|
int16_t aux16[8];
|
||||||
|
float sums [8];
|
||||||
|
memset(sums, 0, 8*sizeof(float));
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint8_t * restrict q4 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
int8_t * restrict a = aux8;
|
||||||
|
for (int l = 0; l < 32; ++l) a[l+ 0] = (int8_t)(q4[l] & 0xF);
|
||||||
|
for (int l = 0; l < 32; ++l) a[l+32] = (int8_t)(q4[l] >> 4);
|
||||||
|
|
||||||
|
sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
|
||||||
|
ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/32; ++j) {
|
||||||
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
||||||
|
q8 += 8; a += 8;
|
||||||
|
for (int l = 0; l < 8; ++l) sums[l] += d * aux16[l];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||||
|
*s = sumf;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
assert(n % QK_K == 0);
|
assert(n % QK_K == 0);
|
||||||
|
@ -2601,7 +2771,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
}
|
}
|
||||||
*s = sum;
|
*s = sum;
|
||||||
|
|
||||||
#elif defined z__AVX2__
|
#elif defined __AVX2__
|
||||||
|
|
||||||
const __m256i m4 = _mm256_set1_epi8(0xF);
|
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||||
const __m256i m2 = _mm256_set1_epi8(3);
|
const __m256i m2 = _mm256_set1_epi8(3);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue