diff --git a/k_quants.c b/k_quants.c index 658a02235..4d524494d 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1035,7 +1035,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vmovl_u8(vget_low_u8(mins)), vmovl_u8(vget_high_u8(mins))}; + const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), @@ -1218,7 +1218,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri uint32_t utmp[4]; const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD const int32x4_t vzero = vdupq_n_s32(0); +#endif const uint8x16_t m0 = vdupq_n_u8(1); const uint8x16_t m1 = vshlq_n_u8(m0, 1); @@ -1265,10 +1267,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), q3h.val[0]); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), q3h.val[1]); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), q3h.val[2]); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]); + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; @@ -1293,10 +1295,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]); + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; @@ -1538,7 +1540,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[0] &= kmask1; - const int16x8_t mins = vreinterpretq_s16_u8(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); sumf -= dmin * vaddvq_s32(prod); @@ -1743,7 +1745,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint32x4_t mzero = vdupq_n_s32(0); + const uint32x4_t mzero = vdupq_n_u32(0); const uint8x16_t mone = vdupq_n_u8(1); const uint8x16_t mtwo = vdupq_n_u8(2);