ggml : Q8_0 unroll x2
This commit is contained in:
parent
88618ab7f5
commit
6e0f0b6ff1
1 changed files with 35 additions and 17 deletions
52
ggml.c
52
ggml.c
|
@ -3079,32 +3079,50 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
|
|||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const block_q8_0 * restrict x0 = &x[i];
|
||||
const block_q8_0 * restrict y0 = &y[i];
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q8_0 * restrict x0 = &x[i + 0];
|
||||
const block_q8_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
const int8x16_t v0_0 = vld1q_s8(x0->qs);
|
||||
const int8x16_t v0_1 = vld1q_s8(x0->qs + 16);
|
||||
const int8x16_t x0_0 = vld1q_s8(x0->qs);
|
||||
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
|
||||
const int8x16_t x1_0 = vld1q_s8(x1->qs);
|
||||
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
|
||||
|
||||
// load y
|
||||
const int8x16_t v1_0 = vld1q_s8(y0->qs);
|
||||
const int8x16_t v1_1 = vld1q_s8(y0->qs + 16);
|
||||
const int8x16_t y0_0 = vld1q_s8(y0->qs);
|
||||
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
|
||||
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
||||
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
||||
vdotq_s32(vdupq_n_s32(0), v0_0, v1_0),
|
||||
vdotq_s32(vdupq_n_s32(0), v0_1, v1_1))), x0->d*y0->d);
|
||||
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
||||
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
|
||||
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
||||
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
||||
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
|
||||
|
||||
#else
|
||||
const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0));
|
||||
const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0));
|
||||
const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1));
|
||||
const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1));
|
||||
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
|
||||
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
|
||||
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
|
||||
const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));
|
||||
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
|
||||
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
||||
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
|
||||
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d);
|
||||
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
||||
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue