ggml : Q8_0 unroll x2

This commit is contained in:
Georgi Gerganov 2023-04-25 22:21:57 +03:00
parent 88618ab7f5
commit 6e0f0b6ff1
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

52
ggml.c
View file

@ -3079,32 +3079,50 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; ++i) {
const block_q8_0 * restrict x0 = &x[i];
const block_q8_0 * restrict y0 = &y[i];
for (int i = 0; i < nb; i += 2) {
const block_q8_0 * restrict x0 = &x[i + 0];
const block_q8_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const int8x16_t v0_0 = vld1q_s8(x0->qs);
const int8x16_t v0_1 = vld1q_s8(x0->qs + 16);
const int8x16_t x0_0 = vld1q_s8(x0->qs);
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
const int8x16_t x1_0 = vld1q_s8(x1->qs);
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
// load y
const int8x16_t v1_0 = vld1q_s8(y0->qs);
const int8x16_t v1_1 = vld1q_s8(y0->qs + 16);
const int8x16_t y0_0 = vld1q_s8(y0->qs);
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
const int8x16_t y1_0 = vld1q_s8(y1->qs);
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0, v1_0),
vdotq_s32(vdupq_n_s32(0), v0_1, v1_1))), x0->d*y0->d);
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
#else
const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0));
const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0));
const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1));
const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1));
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d);
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
#endif
}