Q8_0: unbreak AVX
This commit is contained in:
parent
955ef9a5d5
commit
708540712d
1 changed files with 10 additions and 3 deletions
13
ggml.c
13
ggml.c
|
@ -459,6 +459,14 @@ static inline float hsum_float_8(const __m256 x) {
|
|||
return _mm_cvtss_f32(res);
|
||||
}
|
||||
|
||||
// horizontally add 4 int32_t
|
||||
static inline int hsum_i32_4(const __m128i a) {
|
||||
const __m128i hi64 = _mm_unpackhi_epi64(a, a);
|
||||
const __m128i sum64 = _mm_add_epi32(hi64, a);
|
||||
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
|
||||
}
|
||||
|
||||
// horizontally add 8 int32_t
|
||||
static inline int hsum_i32_8(const __m256i a) {
|
||||
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
|
||||
|
@ -1381,7 +1389,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
y[i].s1 = d * sum1;
|
||||
}
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
// TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x );
|
||||
|
@ -1428,7 +1435,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
|
||||
#if defined(__AVX2__)
|
||||
// Compute the sum of the quants and set y[i].s
|
||||
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
||||
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
|
||||
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
|
||||
|
||||
|
@ -1459,8 +1465,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
|
|||
|
||||
// Compute the sum of the quants and set y[i].s
|
||||
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
|
||||
y[i].s0 = d * hsum_i32_4(s0);
|
||||
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
|
||||
y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0));
|
||||
y[i].s1 = d * hsum_i32_4(s1);
|
||||
|
||||
// Convert int32 to int16
|
||||
ni0 = _mm_packs_epi32( ni0, ni1 );
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue