diff --git a/ggml.c b/ggml.c index 57bd37559..d1b4cbdca 100644 --- a/ggml.c +++ b/ggml.c @@ -2516,7 +2516,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) +#elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2542,37 +2542,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const __m256 xy = mul_sum_i8_pairs_float(bx, by); // Accumulate d0*d1*x*y +#if defined(__AVX2__) acc = _mm256_fmadd_ps( d0d1, xy, acc ); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float * d0 = &x[i].d; - const float * d1 = &y[i].d; - - summs += x[i].m * y[i].s; - - const __m256 d0v = _mm256_broadcast_ss( d0 ); - const __m256 d1v = _mm256_broadcast_ss( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - - const __m256 xy = mul_sum_i8_pairs_float(bx, by); - - // Accumulate d0*d1*x*y +#else acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif } *s = hsum_float_8(acc) + summs; @@ -3166,7 +3140,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) +#elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -3180,25 +3154,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * const __m256 q = mul_sum_i8_pairs_float(bx, by); // Multiply q with scale and accumulate +#if defined(__AVX2__) acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - // Multiply q with scale and accumulate +#else acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif } *s = hsum_float_8(acc);