ggml : implement vaddvq when missing

This commit is contained in:
Georgi Gerganov 2023-04-13 18:16:35 +03:00
parent 2ae3164d29
commit 14a0b207bc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

45
ggml.c
View file

@ -492,26 +492,43 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
#endif
#if __ARM_NEON
// ARMv8 is not available: provide fallback implementations of the vaddvq_* reductions
#if !defined(__ARM_FEATURE_QRDMX)
inline static int16_t vaddvq_s16(int16x8_t v) {
const int16x4_t v1 = vadd_s16(vget_low_s16(v), vget_high_s16(v));
return vaddv_s16(v1);
// Horizontal add: returns the sum of all 16 unsigned 8-bit lanes of v.
// The running total is kept in 16 bits, so the largest possible sum
// (16 * 255) is represented without wrapping.
inline static uint16_t vaddvq_u8(uint8x16_t v) {
    uint16_t total = 0;

    total += vgetq_lane_u8(v,  0);
    total += vgetq_lane_u8(v,  1);
    total += vgetq_lane_u8(v,  2);
    total += vgetq_lane_u8(v,  3);
    total += vgetq_lane_u8(v,  4);
    total += vgetq_lane_u8(v,  5);
    total += vgetq_lane_u8(v,  6);
    total += vgetq_lane_u8(v,  7);
    total += vgetq_lane_u8(v,  8);
    total += vgetq_lane_u8(v,  9);
    total += vgetq_lane_u8(v, 10);
    total += vgetq_lane_u8(v, 11);
    total += vgetq_lane_u8(v, 12);
    total += vgetq_lane_u8(v, 13);
    total += vgetq_lane_u8(v, 14);
    total += vgetq_lane_u8(v, 15);

    return total;
}
inline static uint16_t vaddvq_u16(uint16x8_t v) {
const uint16x4_t v1 = vadd_u16(vget_low_u16(v), vget_high_u16(v));
return vaddv_u16(v1);
// Horizontal add: returns the sum of all 8 signed 16-bit lanes of v.
// Each lane is widened to 32 bits before it is added, so the sum of
// eight 16-bit values cannot overflow.
inline static int32_t vaddvq_s16(int16x8_t v) {
    int32_t total = 0;

    total += (int32_t)vgetq_lane_s16(v, 0);
    total += (int32_t)vgetq_lane_s16(v, 1);
    total += (int32_t)vgetq_lane_s16(v, 2);
    total += (int32_t)vgetq_lane_s16(v, 3);
    total += (int32_t)vgetq_lane_s16(v, 4);
    total += (int32_t)vgetq_lane_s16(v, 5);
    total += (int32_t)vgetq_lane_s16(v, 6);
    total += (int32_t)vgetq_lane_s16(v, 7);

    return total;
}
// Horizontal add: returns the sum of all 8 unsigned 16-bit lanes of v.
// Each lane is widened to 32 bits before it is added, so the sum of
// eight 16-bit values cannot wrap.
inline static uint32_t vaddvq_u16(uint16x8_t v) {
    uint32_t total = 0;

    total += (uint32_t)vgetq_lane_u16(v, 0);
    total += (uint32_t)vgetq_lane_u16(v, 1);
    total += (uint32_t)vgetq_lane_u16(v, 2);
    total += (uint32_t)vgetq_lane_u16(v, 3);
    total += (uint32_t)vgetq_lane_u16(v, 4);
    total += (uint32_t)vgetq_lane_u16(v, 5);
    total += (uint32_t)vgetq_lane_u16(v, 6);
    total += (uint32_t)vgetq_lane_u16(v, 7);

    return total;
}
// Horizontal add: returns the sum of the 4 signed 32-bit lanes of v.
// Implemented with per-lane extraction only; the across-vector vaddv_s32
// intrinsic must not be used here because this fallback exists precisely
// for targets where the ARMv8 reductions are missing.
inline static int32_t vaddvq_s32(int32x4_t v) {
    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
// Horizontal add: returns the sum of the 4 float lanes of v, added in
// lane order (0, 1, 2, 3).
// Implemented with per-lane extraction only; the across-vector vaddv_f32
// intrinsic must not be used here because this fallback exists precisely
// for targets where the ARMv8 reductions are missing.
inline static float vaddvq_f32(float32x4_t v) {
    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}
#endif
@ -2313,10 +2330,10 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
const uint16x8_t pl1l = vmull_u8(vget_low_s8 (v0_1l), vget_low_u8 (v1_1l));
const uint16x8_t pl1h = vmull_u8(vget_high_s8(v0_1l), vget_high_u8(v1_1l));
const uint16x8_t ph1l = vmull_u8(vget_low_s8 (v0_1h), vget_low_u8 (v1_1h));
const uint16x8_t ph1h = vmull_u8(vget_high_s8(v0_1h), vget_high_u8(v1_1h));
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);