Q5_K dot product for ARM_NEON

This commit is contained in:
Iwan Kawrakow 2023-05-30 12:31:21 +03:00
parent 5ca15ce155
commit a197eb50d1

View file

@ -1293,24 +1293,107 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K; const int nb = n / QK_K;
#ifdef z__ARM_NEON
GGML_ASSERT(false);
#elif defined __AVX2__
static const uint32_t kmask1 = 0x3f3f3f3f; static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f; static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303; static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint32x4_t mzero = vdupq_n_s32(0);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
int8x16x4_t q5bytes;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
int32_t sumi_mins = vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
uint8x16x4_t q5h;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
#else
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
#endif
}
sumf += d * sumi - dmin * sumi_mins;
}
*s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF); const __m256i m4 = _mm256_set1_epi8(0xF);
const __m128i mzero = _mm_setzero_si128(); const __m128i mzero = _mm_setzero_si128();
const __m256i mone = _mm256_set1_epi8(1); const __m256i mone = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
uint32_t utmp[4];
float summs = 0.f; float summs = 0.f;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
@ -1385,12 +1468,6 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2]; const uint8_t * mins = (const uint8_t*)&utmp[2];