iq3_s before sllv
This commit is contained in:
parent
eccc609efa
commit
39e816e54e
1 changed files with 105 additions and 0 deletions
105
ggml-quants.c
105
ggml-quants.c
|
@ -10784,6 +10784,111 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
|
||||||
|
|
||||||
*s = hsum_float_8(accumf);
|
*s = hsum_float_8(accumf);
|
||||||
|
|
||||||
|
#elif defined(__AVX__)
|
||||||
|
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||||
|
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||||
|
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||||
|
};
|
||||||
|
|
||||||
|
const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
|
||||||
|
const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
|
||||||
|
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
|
||||||
|
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
|
||||||
|
|
||||||
|
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
||||||
|
const __m256i idx_mask = _mm256_set1_epi32(256);
|
||||||
|
|
||||||
|
typedef union {
|
||||||
|
__m128i vec_0[2];
|
||||||
|
__m128i vec_1[2];
|
||||||
|
uint32_t index[16];
|
||||||
|
} index_t;
|
||||||
|
|
||||||
|
index_t idx;
|
||||||
|
|
||||||
|
__m256 accumf = _mm256_setzero_ps();
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
|
const uint8_t * restrict qs = x[i].qs;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
__m128i sumi1_0 = _mm_setzero_si128();
|
||||||
|
__m128i sumi1_1 = _mm_setzero_si128();
|
||||||
|
__m128i sumi2_0 = _mm_setzero_si128();
|
||||||
|
__m128i sumi2_1 = _mm_setzero_si128();
|
||||||
|
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||||
|
const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
|
||||||
|
const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
|
||||||
|
const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
|
||||||
|
const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
|
||||||
|
idx.vec_0[0] = _mm_set1_epi32(qh[ib32+0]);
|
||||||
|
idx.vec_1[0] = _mm_set1_epi32(qh[ib32+0]);
|
||||||
|
idx.vec_0[1] = _mm_set1_epi32(qh[ib32+1]);
|
||||||
|
idx.vec_1[1] = _mm_set1_epi32(qh[ib32+1]);
|
||||||
|
|
||||||
|
// TODO this section
|
||||||
|
idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
|
||||||
|
idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
|
||||||
|
|
||||||
|
idx.vec_0[0] = _mm_or_si128(idx.vec_0[0], _mm_cvtepi16_epi32(idx_l_0));
|
||||||
|
idx.vec_1[0] = _mm_or_si128(idx.vec_1[0], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
|
||||||
|
idx.vec_0[1] = _mm_or_si128(idx.vec_0[1], _mm_cvtepi16_epi32(idxl_1));
|
||||||
|
idx.vec_1[1] = _mm_or_si128(idx.vec_1[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
|
||||||
|
|
||||||
|
const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
|
||||||
|
const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
|
||||||
|
const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
|
||||||
|
const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]);
|
||||||
|
|
||||||
|
__m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
|
||||||
|
__m128i aux128_1 = aux128_0;
|
||||||
|
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
|
||||||
|
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
|
||||||
|
const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
|
||||||
|
const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
|
||||||
|
const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
|
||||||
|
const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
|
||||||
|
|
||||||
|
aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
|
||||||
|
aux128_1 = aux128_0;
|
||||||
|
aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
|
||||||
|
aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
|
||||||
|
const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
|
||||||
|
const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
|
||||||
|
const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
|
||||||
|
const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
|
||||||
|
|
||||||
|
signs += 4;
|
||||||
|
|
||||||
|
const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
|
||||||
|
const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
|
||||||
|
const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
|
||||||
|
const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
|
||||||
|
const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
|
||||||
|
const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
|
||||||
|
const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
|
||||||
|
const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
|
||||||
|
const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
|
||||||
|
const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
|
||||||
|
sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
|
||||||
|
sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
|
||||||
|
sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
|
||||||
|
sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
|
||||||
|
}
|
||||||
|
|
||||||
|
accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_8(accumf);
|
||||||
|
|
||||||
#elif defined(__POWER9_VECTOR__)
|
#elif defined(__POWER9_VECTOR__)
|
||||||
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||||
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue