diff --git a/ggml.c b/ggml.c index 096ccacfb..a49be8858 100644 --- a/ggml.c +++ b/ggml.c @@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // -#if __AVX__ || __AVX2__ || __AVX512F__ +#if __AVX__ || __AVX2__ || __AVX512F__ || __SSE3__ // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors @@ -485,6 +485,16 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { return _mm_madd_epi16(ones, dot); } +// horizontally add 4 floats +static inline float hsum_float_4(const __m128 x) { + __m128 res =_mm_hadd_ps(x, x); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif + +#if __AVX__ || __AVX2__ || __AVX512F__ // horizontally add 8 floats static inline float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -2129,6 +2139,40 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * } *s = hsum_float_8(acc); +#elif defined(__SSE3__) + // Initialize accumulator with zeros + __m128 acc = _mm_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m128 d = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) ); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx = _mm_and_si128(lowMask, tmp); + __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + + bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + + // Apply the scale, and accumulate + acc = _mm_add_ps(_mm_mul_ps( d, p0 ), acc); + acc = _mm_add_ps(_mm_mul_ps( d, p1 ), acc); + } + + *s = hsum_float_4(acc); #else // scalar float sumf = 0.0;