diff --git a/k_quants.c b/k_quants.c
index 4b9208b62..900552a7a 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -1009,7 +1009,101 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
     const int nb = n / QK_K;
 
-#ifdef __AVX2__
+#ifdef __ARM_NEON
+
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    int8x16x4_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const int16x8x2_t mins16 = {vmovl_u8(vget_low_u8(mins)), vmovl_u8(vget_high_u8(mins))};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
+            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+            q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 2), m3));
+            q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 2), m3));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_1.val[0])) * aux[is+0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_1.val[1])) * aux[is+1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_1.val[2])) * aux[is+2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_1.val[3])) * aux[is+3];
+#else
+            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
+                                     vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
+            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
+                                     vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
+            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
+                                     vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
+            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
+                                     vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
+            isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3];
+#endif
+            is += 4;
+
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 4), m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 4), m3));
+            q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 6), m3));
+            q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 6), m3));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_2.val[0])) * aux[is+0];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_2.val[1])) * aux[is+1];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_2.val[2])) * aux[is+2];
+            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_2.val[3])) * aux[is+3];
+#else
+            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
+                           vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
+            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
+                           vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
+            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
+                           vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
+            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
+                           vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
+            isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3];
+#endif
+            is += 4;
+
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
 
     const __m256i m3 = _mm256_set1_epi8(3);
     const __m128i m4 = _mm_set1_epi8(0xF);
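For reference, the following scalar sketch shows the arithmetic the NEON path above vectorizes. It is not part of the patch: the function name vec_dot_q2_K_q8_K_scalar is made up for illustration, and it assumes the block_q2_K / block_q8_K layouts declared in k_quants.h (QK_K = 256 quants per super-block, 16 sub-blocks of 16 quants, one packed 4-bit scale and 4-bit min per sub-block); the sub-block/bit-plane ordering is inferred from the NEON code.

#include "ggml.h"
#include "k_quants.h"

// Hypothetical scalar reference (illustration only, not the patch's code).
static void vec_dot_q2_K_q8_K_scalar(const int n, float * restrict s,
                                     const block_q2_K * restrict x,
                                     const block_q8_K * restrict y) {
    const int nb = n / QK_K;
    float sum = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q2 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;
        const uint8_t * restrict sc = x[i].scales;

        // the per-sub-block mins only need the pre-computed q8 sub-block sums (bsums)
        int summs = 0;
        for (int j = 0; j < QK_K/16; ++j) summs += (sc[j] >> 4) * y[i].bsums[j];

        // 2-bit quants: each 32-byte half of qs packs 128 quants, one bit-pair
        // plane (shift 0, 2, 4, 6) per group of 32; the 4-bit scale advances
        // every 16 quants, mirroring aux[is+...] in the NEON loop
        int isum = 0;
        int is = 0;
        for (int j = 0; j < QK_K/128; ++j) {
            for (int shift = 0; shift < 8; shift += 2) {
                for (int l = 0; l < 32; ++l) {
                    isum += ((q2[l] >> shift) & 3) * q8[l] * (sc[is + l/16] & 0xF);
                }
                q8 += 32;
                is += 2;
            }
            q2 += 32;
        }

        const float d    =  y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
        sum += d * isum + dmin * summs;
    }
    *s = sum;
}

The NEON path computes the same two terms per super-block: the min correction via a widened multiply of the sixteen 4-bit mins against bsums, and the quant dot product via vdotq_s32 (or vmull_s8 pairs when the dot-product extension is unavailable), scaled by the per-sub-block scales stored in aux.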