A very slightly faster ARM_NEON Q3_K dot

This commit is contained in:
Iwan Kawrakow 2023-05-31 08:46:30 +03:00
parent 13264fa067
commit 4faa040c20

View file

@ -904,6 +904,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m1 = vshlq_n_u8(m0, 1); const uint8x16_t m1 = vshlq_n_u8(m0, 1);
const uint8x16_t m2 = vshlq_n_u8(m0, 2); const uint8x16_t m2 = vshlq_n_u8(m0, 2);
const uint8x16_t m3 = vshlq_n_u8(m0, 3); const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;
int8x16x4_t q3bytes; int8x16x4_t q3bytes;
@ -930,7 +931,8 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
const int8_t * scale = (const int8_t *)utmp; int8_t * scale = (int8_t *)utmp;
for (int j = 0; j < 16; ++j) scale[j] -= m32;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
@ -949,10 +951,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]);
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * (scale[0] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * (scale[1] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * (scale[2] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * (scale[3] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
#else #else
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
@ -962,10 +964,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
isum += vaddvq_s16(p0) * (scale[0] - 32) + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
vaddvq_s16(p1) * (scale[1] - 32) +
vaddvq_s16(p2) * (scale[2] - 32) +
vaddvq_s16(p3) * (scale[3] - 32);
#endif #endif
scale += 4; scale += 4;
@ -974,19 +973,16 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]); q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]);
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]); q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]);
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]); q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]);
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]);
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * (scale[0] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * (scale[1] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * (scale[2] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * (scale[3] - 32); isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
#else #else
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
@ -996,13 +992,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
isum += vaddvq_s16(p0) * (scale[0] - 32) + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
vaddvq_s16(p1) * (scale[1] - 32) +
vaddvq_s16(p2) * (scale[2] - 32) +
vaddvq_s16(p3) * (scale[3] - 32);
#endif #endif
scale += 4; scale += 4;
if (j == 0) {
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
}
} }
sum += d * isum; sum += d * isum;