iq2_xxs
This commit is contained in:
parent
75370d779e
commit
65765c9ea9
1 changed files with 56 additions and 1 deletions
|
@ -8819,7 +8819,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
|
||||
#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
|
||||
static const int8_t keven_signs_q2xs[1024] = {
|
||||
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
||||
|
@ -8952,6 +8952,61 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
|||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
|
||||
#elif defined(__AVX__)
|
||||
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
||||
|
||||
uint32_t aux32[4];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint16_t * restrict q2 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
__m128i sumi1_0 = _mm_setzero_si128();
|
||||
__m128i sumi1_1 = _mm_setzero_si128();
|
||||
__m128i sumi2_0 = _mm_setzero_si128();
|
||||
__m128i sumi2_1 = _mm_setzero_si128();
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
|
||||
const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
|
||||
const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]]);
|
||||
const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
|
||||
const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
|
||||
const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
||||
const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
|
||||
const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
|
||||
const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
|
||||
const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
|
||||
const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
|
||||
const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
|
||||
const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
|
||||
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
|
||||
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
|
||||
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
|
||||
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
|
||||
const uint16_t ls1 = aux32[1] >> 28;
|
||||
const uint16_t ls2 = aux32[3] >> 28;
|
||||
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
|
||||
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
|
||||
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
|
||||
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
|
||||
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
|
||||
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
|
||||
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
|
||||
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
|
||||
}
|
||||
|
||||
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
|
||||
|
||||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
vector float vsumf0 = vec_splats(0.0f);
|
||||
vector float vsumf1 = vec_splats(0.0f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue