k_quants: WIP super-blocks with 64 weights
Q2_K scalar and AVX2 work. The Q2_K AVX2 version is way too slow (it is actually slower than the scalar implementation).
parent 1f6195c2f2
commit aebd5471e9

1 changed file with 183 additions and 1 deletion

k_quants.c (+183 −1)
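For orientation, here is a minimal sketch of the super-block types the new code indexes, as they would look with 64-weight super-blocks (QK_K == 64). Field names follow k_quants.h, but the exact declarations at this commit are an assumption inferred from the diff, not part of it:

// Sketch only (assumed from how the new code indexes these fields):
// 2-bit quantization, QK_K == 64 weights per super-block, 4 sub-blocks of 16.
typedef struct {
    uint8_t     scales[QK_K/16]; // 4 bytes: sub-scale in low nibble, sub-min in high nibble
    uint8_t     qs[QK_K/4];      // 16 bytes: four 2-bit quants per byte (bit shifts 0/2/4/6)
    ggml_fp16_t d;               // super-block scale for the sub-scales
    ggml_fp16_t dmin;            // super-block scale for the sub-mins
} block_q2_K;

// 8-bit activation blocks the dot product consumes.
typedef struct {
    float   d;                 // delta
    int8_t  qs[QK_K];          // quants
    int16_t bsums[QK_K/16];    // sums of groups of 16 quants (used for the min terms)
} block_q8_K;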
@@ -1203,6 +1203,7 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
+#if QK_K == 256
 void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 
     const block_q2_K * restrict x = vx;
@@ -1402,6 +1403,187 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #endif
 }
 
+#else
+
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+
+    int8x16x2_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+// We use this macro instead of a function call because for some reason
+// the code runs 2-3% slower, even if the function is declared inline
+#if defined(__ARM_FEATURE_DOTPROD)
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+#else
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        {\
+            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
+                                           vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
+            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
+                                           vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
+            isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
+        }
+#endif
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+
+            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+            MULTIPLY_ACCUM_WITH_SCALE(0);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+            is += 8;
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined z__AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+        const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+
+        const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
+        const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
+        const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
+        const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#else
+
+    float sumf = 0;
+
+    int isum[4];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        isum[0] = isum[1] = isum[2] = isum[3] = 0;
+        for (int l = 0; l < 16; ++l) {
+            isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
+            isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
+            isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
+            isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
+        }
+        for (int l = 0; l < 4; ++l) {
+            isum[l] *= (sc[l] & 0xF);
+        }
+        sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+#endif
+
 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
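A note on the scalar fallback above: each byte of qs carries one 2-bit quant for each of the four 16-weight sub-blocks (bit shifts 0/2/4/6), and scales[j] packs the sub-block scale in its low nibble and the sub-block min in its high nibble, so per super-block the result is dall * sum_j (sc[j] & 0xF) * (q2·q8)_j - dmin * sum_j (sc[j] >> 4) * bsums[j]. The following standalone sketch (the helper and test pattern are illustrative, not from the commit) shows the unpacking order the scalar loop assumes:

#include <stdint.h>
#include <stdio.h>

// Hypothetical helper: expand one 16-byte q2 payload into 64 2-bit quants,
// in the same order as the scalar loop above (sub-block j uses bit shift 2*j).
static void unpack_q2_64(const uint8_t q2[16], uint8_t q[64]) {
    for (int l = 0; l < 16; ++l) {
        q[l +  0] = (q2[l] >> 0) & 3;  // sub-block 0, later scaled by sc[0] & 0xF
        q[l + 16] = (q2[l] >> 2) & 3;  // sub-block 1
        q[l + 32] = (q2[l] >> 4) & 3;  // sub-block 2
        q[l + 48] = (q2[l] >> 6) & 3;  // sub-block 3
    }
}

int main(void) {
    uint8_t q2[16], q[64];
    for (int l = 0; l < 16; ++l) q2[l] = (uint8_t)(5*l + 3);  // arbitrary test pattern
    unpack_q2_64(q2, q);
    for (int l = 0; l < 64; ++l) printf("%d%c", q[l], (l % 16 == 15) ? '\n' : ' ');
    return 0;
}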
@@ -2029,7 +2211,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 
     *s = sumf;
 
-#elif defined z__AVX2__
+#elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
     const __m256i m1 = _mm256_set1_epi16(1);
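Since the commit message reports the new path only via timing, a quick sanity check of a QK_K == 64 build is a round-trip dot product against an fp32 reference. A hypothetical harness (entry points, signatures, and availability of the quantize routines at this WIP commit are assumptions; verify against k_quants.h before use):

#include <stdio.h>
#include "k_quants.h"

int main(void) {
    enum { N = 4 * QK_K };               // four super-blocks
    float a[N], b[N], ref = 0;
    for (int i = 0; i < N; ++i) {        // arbitrary test data
        a[i] = 0.01f * (i % 37) - 0.2f;
        b[i] = 0.02f * (i % 23) - 0.3f;
    }

    block_q2_K xq[N / QK_K];
    block_q8_K yq[N / QK_K];
    quantize_row_q2_K(a, xq, N);         // assumed to support QK_K == 64 here
    quantize_row_q8_K(b, yq, N);

    float s;
    ggml_vec_dot_q2_K_q8_K(N, &s, xq, yq);

    for (int i = 0; i < N; ++i) ref += a[i] * b[i];
    // The two values agree only up to quantization error.
    printf("quantized dot = %g, fp32 ref = %g\n", s, ref);
    return 0;
}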