diff --git a/k_quants.c b/k_quants.c index fa46b704c..66e0237e2 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1521,14 +1521,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri static const uint32_t kmask2 = 0x0f0f0f0f; static const uint32_t kmask3 = 0x03030303; - uint32_t utmp[4]; + uint32_t utmp[3]; #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD const uint32x4_t mzero = vdupq_n_s32(0); +#endif - int8x16x4_t q4bytes; + int8x16x2_t q4bytes; + int8x16x2_t q8bytes; float sumf = 0; @@ -1540,17 +1543,15 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; + + const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; utmp[0] &= kmask1; - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int16x8_t mins = vreinterpretq_s16_u8(vmovl_u8(vreinterpret_u8_u32(mins8))); const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); + sumf -= dmin * vaddvq_s32(prod); const uint8_t * scales = (const uint8_t *)utmp; @@ -1559,45 +1560,51 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri //int32x4_t isum = mzero; - int32_t sumi = 0; + int32_t sumi1 = 0; + int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - q4bytes.val[2] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[3] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); -#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - // This is slightly slower than the version below - //isum = vmlaq_s32(isum, vdupq_n_s32(*scales++), - // vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1])); - //isum = vmlaq_s32(isum, vdupq_n_s32(*scales++), - // vdotq_s32(vdotq_s32(mzero, q4bytes.val[2], q8bytes.val[2]), q4bytes.val[3], q8bytes.val[3])); + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q4bytes.val[2], q8bytes.val[2]), q4bytes.val[3], q8bytes.val[3])) * *scales++; + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; #else - + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q4bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q4bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; #endif } - sumf += d * sumi - dmin * sumi_mins; - //sumf += d * vaddvq_s32(isum) - dmin * sumi_mins; + sumf += d * (sumi1 + sumi2); }