iq3_s
This commit is contained in:
parent
39e816e54e
commit
99f666c1b6
1 changed files with 31 additions and 24 deletions
|
@ -10798,12 +10798,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
|
|||
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
|
||||
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
|
||||
|
||||
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
||||
const __m256i idx_mask = _mm256_set1_epi32(256);
|
||||
const __m128i idx_mask = _mm_set1_epi32(256);
|
||||
|
||||
typedef union {
|
||||
__m128i vec_0[2];
|
||||
__m128i vec_1[2];
|
||||
__m128i vec[4];
|
||||
uint32_t index[16];
|
||||
} index_t;
|
||||
|
||||
|
@ -10828,24 +10826,33 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
|
|||
const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
|
||||
const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
|
||||
const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
|
||||
idx.vec_0[0] = _mm_set1_epi32(qh[ib32+0]);
|
||||
idx.vec_1[0] = _mm_set1_epi32(qh[ib32+0]);
|
||||
idx.vec_0[1] = _mm_set1_epi32(qh[ib32+1]);
|
||||
idx.vec_1[1] = _mm_set1_epi32(qh[ib32+1]);
|
||||
idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
|
||||
idx.vec[1] = idx.vec[0];
|
||||
idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
|
||||
idx.vec[3] = idx.vec[2];
|
||||
|
||||
// TODO this section
|
||||
idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
|
||||
idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
|
||||
// AVX has no sllv so we have to do this
|
||||
for (int j = 0; j <= 2; j += 2) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
int32_t * curn = (int32_t *) &idx.vec[j] + k;
|
||||
*curn = *curn << (8 - k);
|
||||
}
|
||||
}
|
||||
|
||||
idx.vec_0[0] = _mm_or_si128(idx.vec_0[0], _mm_cvtepi16_epi32(idx_l_0));
|
||||
idx.vec_1[0] = _mm_or_si128(idx.vec_1[0], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
|
||||
idx.vec_0[1] = _mm_or_si128(idx.vec_0[1], _mm_cvtepi16_epi32(idxl_1));
|
||||
idx.vec_1[1] = _mm_or_si128(idx.vec_1[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
|
||||
idx.vec[0] = _mm_and_si128(idx.vec[0], idx_mask);
|
||||
idx.vec[1] = _mm_and_si128(idx.vec[1], idx_mask);
|
||||
idx.vec[2] = _mm_and_si128(idx.vec[2], idx_mask);
|
||||
idx.vec[3] = _mm_and_si128(idx.vec[3], idx_mask);
|
||||
|
||||
idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
|
||||
idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
|
||||
idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
|
||||
idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
|
||||
|
||||
const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
|
||||
const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
|
||||
const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
|
||||
const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
|
||||
const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
|
||||
const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
|
||||
|
||||
__m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
|
||||
__m128i aux128_1 = aux128_0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue