This commit is contained in:
netrunnereve 2024-06-14 23:36:18 -04:00
parent 520361f318
commit 592618656a

View file

@ -10114,6 +10114,63 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
*s = 0.25f * hsum_float_8(accumf); *s = 0.25f * hsum_float_8(accumf);
#elif defined(__AVX__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[2];
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict gas = x[i].qs + QK_K/4;
const int8_t * restrict q8 = y[i].qs;
__m128i sumi1_0 = _mm_setzero_si128();
__m128i sumi1_1 = _mm_setzero_si128();
__m128i sumi2_0 = _mm_setzero_si128();
__m128i sumi2_1 = _mm_setzero_si128();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
q3 += 8;
const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
q3 += 8;
memcpy(aux32, gas, 8); gas += 8;
const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
const uint16_t ls1 = aux32[0] >> 28;
const uint16_t ls2 = aux32[1] >> 28;
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
}
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
}
*s = 0.25f * hsum_float_8(accumf);
#elif defined(__POWER9_VECTOR__) #elif defined(__POWER9_VECTOR__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;