Fix docker build

I have been sloppy with vector reinterpret casts on ARM_NEON.
It seems clang is very forgiving in that regard.
This commit is contained in:
Iwan Kawrakow 2023-06-03 19:03:15 +03:00
parent 6ef13823b8
commit 0a71a4e6d3

View file

@ -1035,7 +1035,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
const int16x8x2_t mins16 = {vmovl_u8(vget_low_u8(mins)), vmovl_u8(vget_high_u8(mins))}; const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@ -1218,7 +1218,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
uint32_t utmp[4]; uint32_t utmp[4];
const uint8x16_t m3b = vdupq_n_u8(0x3); const uint8x16_t m3b = vdupq_n_u8(0x3);
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t vzero = vdupq_n_s32(0); const int32x4_t vzero = vdupq_n_s32(0);
#endif
const uint8x16_t m0 = vdupq_n_u8(1); const uint8x16_t m0 = vdupq_n_u8(1);
const uint8x16_t m1 = vshlq_n_u8(m0, 1); const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@ -1265,10 +1267,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), q3h.val[0]); q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), q3h.val[1]); q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), q3h.val[2]); q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
@ -1293,10 +1295,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]); q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]); q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]); q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]); q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
@ -1538,7 +1540,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1; utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u8(vmovl_u8(vreinterpret_u8_u32(mins8))); const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod); sumf -= dmin * vaddvq_s32(prod);
@ -1743,7 +1745,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#ifdef __ARM_NEON #ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint32x4_t mzero = vdupq_n_s32(0); const uint32x4_t mzero = vdupq_n_u32(0);
const uint8x16_t mone = vdupq_n_u8(1); const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2); const uint8x16_t mtwo = vdupq_n_u8(2);