iq2_s
This commit is contained in:
parent
592618656a
commit
dcfee06594
1 changed files with 92 additions and 0 deletions
|
@ -9777,6 +9777,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
|
||||||
|
|
||||||
*s = 0.125f * hsum_float_8(accumf);
|
*s = 0.125f * hsum_float_8(accumf);
|
||||||
|
|
||||||
|
#elif defined(__AVX__)
|
||||||
|
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||||
|
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||||
|
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||||
|
};
|
||||||
|
|
||||||
|
const __m128i m4 = _mm_set1_epi8(0xf);
|
||||||
|
const __m128i m1 = _mm_set1_epi8(1);
|
||||||
|
|
||||||
|
const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
|
||||||
|
const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
|
||||||
|
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
|
||||||
|
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
|
||||||
|
|
||||||
|
uint64_t aux64;
|
||||||
|
|
||||||
|
__m256 accumf = _mm256_setzero_ps();
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
|
const uint8_t * restrict qs = x[i].qs;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
memcpy(&aux64, x[i].scales, 8);
|
||||||
|
const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
|
||||||
|
const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
|
||||||
|
const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
|
||||||
|
|
||||||
|
__m128i sumi1_0 = _mm_setzero_si128();
|
||||||
|
__m128i sumi1_1 = _mm_setzero_si128();
|
||||||
|
__m128i sumi2_0 = _mm_setzero_si128();
|
||||||
|
__m128i sumi2_1 = _mm_setzero_si128();
|
||||||
|
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||||
|
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
|
||||||
|
iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
|
||||||
|
const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
|
||||||
|
iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
|
||||||
|
const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
|
||||||
|
iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
|
||||||
|
const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
|
||||||
|
iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
|
||||||
|
qs += 8;
|
||||||
|
|
||||||
|
__m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
|
||||||
|
__m128i aux128_1 = aux128_0;
|
||||||
|
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
|
||||||
|
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
|
||||||
|
const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
|
||||||
|
const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
|
||||||
|
const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
|
||||||
|
const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
|
||||||
|
|
||||||
|
aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
|
||||||
|
aux128_1 = aux128_0;
|
||||||
|
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
|
||||||
|
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
|
||||||
|
const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
|
||||||
|
const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
|
||||||
|
const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
|
||||||
|
const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
|
||||||
|
|
||||||
|
signs += 4;
|
||||||
|
|
||||||
|
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
|
||||||
|
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
|
||||||
|
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
|
||||||
|
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
|
||||||
|
|
||||||
|
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
|
||||||
|
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
|
||||||
|
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
|
||||||
|
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
|
||||||
|
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
|
||||||
|
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
|
||||||
|
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
|
||||||
|
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
|
||||||
|
}
|
||||||
|
|
||||||
|
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = 0.125f * hsum_float_8(accumf);
|
||||||
|
|
||||||
#elif defined(__POWER9_VECTOR__)
|
#elif defined(__POWER9_VECTOR__)
|
||||||
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||||
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue