iq3_xs: working scalar and AVX2 dot products
This commit is contained in:
parent
5be4e7ac4a
commit
f1255c50c0
1 changed files with 25 additions and 11 deletions
|
@ -9491,7 +9491,8 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
__m256 accumf = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint8_t * restrict q3 = x[i].qs;
|
||||
const uint8_t * restrict qs = x[i].qs;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
__m256i sumi1 = _mm256_setzero_si256();
|
||||
|
@ -9499,12 +9500,24 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
||||
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
||||
q3 += 8;
|
||||
const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
||||
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
||||
q3 += 8;
|
||||
const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
|
||||
iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
|
||||
iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
|
||||
iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
|
||||
iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
|
||||
iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
|
||||
iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
|
||||
iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
|
||||
qs += 8;
|
||||
const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
|
||||
iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
|
||||
iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
|
||||
iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
|
||||
iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
|
||||
iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
|
||||
iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
|
||||
iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
|
||||
qs += 8;
|
||||
memcpy(aux32, gas, 8); gas += 8;
|
||||
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
|
||||
signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
|
||||
|
@ -9535,7 +9548,8 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
float sumf = 0.f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint8_t * restrict q3 = x[i].qs;
|
||||
const uint8_t * restrict qs = x[i].qs;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
int32_t bsum = 0;
|
||||
|
@ -9544,8 +9558,8 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
const uint32_t ls = 2*(aux32 >> 28) + 1;
|
||||
int32_t sumi = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32] << (8-2*l)) & 256)));
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32] << (7-2*l)) & 256)));
|
||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
|
||||
|
@ -9553,7 +9567,7 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
}
|
||||
q8 += 8;
|
||||
}
|
||||
q3 += 8;
|
||||
qs += 8;
|
||||
bsum += sumi * ls;
|
||||
}
|
||||
sumf += d * bsum;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue