ggml : fix AVX dot products

This commit is contained in:
Georgi Gerganov 2023-05-19 19:13:09 +03:00
parent 8b7132972d
commit a44349753a

44
ggml.c
View file

@ -2364,7 +2364,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
@ -2388,8 +2388,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@ -2406,8 +2406,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
#endif
}
@ -2421,9 +2421,9 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
// Main loop
for (int i = 0; i < nb; ++i) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
const float d1 = GGML_FP16_TO_FP32(y[i].d);
const float d1 = y[i].d;
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
const __m256 d0v = _mm256_set1_ps( d0 );
const __m256 d1v = _mm256_set1_ps( d1 );
@ -2460,7 +2460,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i]).d*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
sumf += (GGML_FP16_TO_FP32(x[i]).d*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;
@ -2635,7 +2635,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i].d);
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
@ -2659,7 +2659,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i].d);
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@ -2741,8 +2741,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const uint8x16_t m4b = vdupq_n_u8(0x0F);
summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
// extract the 5th bit via lookup table ((b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
@ -2787,10 +2787,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@ -2807,8 +2807,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
#endif
}
@ -2826,7 +2826,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const block_q5_1 * restrict x0 = &x[i];
const block_q8_1 * restrict y0 = &y[i];
summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
const v128_t m4b = wasm_i8x16_splat(0x0F);
@ -2875,7 +2875,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)));
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d));
}
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@ -2890,14 +2890,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
bx = _mm256_or_si256(bx, bxhi);
const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d));
const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
@ -2931,7 +2931,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);
const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d));
const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
@ -2960,7 +2960,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;