iq3_xs: make scalar and AVX2 work for new version
This commit is contained in:
parent
eacff4aa81
commit
1fef4b8b68
1 changed files with 49 additions and 25 deletions
|
@ -9499,16 +9499,23 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
|
||||
#elif defined(__AVX2__)
|
||||
|
||||
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
||||
static const char k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||
};
|
||||
|
||||
uint32_t aux32[2];
|
||||
static const char k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
};
|
||||
|
||||
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
|
||||
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
|
||||
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint8_t * restrict qs = x[i].qs;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
||||
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
__m256i sumi1 = _mm256_setzero_si256();
|
||||
__m256i sumi2 = _mm256_setzero_si256();
|
||||
|
@ -9533,17 +9540,23 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
|
||||
iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
|
||||
qs += 8;
|
||||
memcpy(aux32, gas, 8); gas += 8;
|
||||
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
|
||||
signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
|
||||
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
|
||||
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
||||
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
|
||||
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
|
||||
|
||||
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
||||
const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
|
||||
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
|
||||
|
||||
aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
||||
const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
|
||||
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
|
||||
|
||||
signs += 4;
|
||||
|
||||
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
||||
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
||||
const uint16_t ls1 = aux32[0] >> 28;
|
||||
const uint16_t ls2 = aux32[1] >> 28;
|
||||
const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
|
||||
const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
|
||||
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
|
||||
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
|
||||
sumi1 = _mm256_add_epi32(sumi1, p1);
|
||||
|
@ -9558,32 +9571,43 @@ void ggml_vec_dot_iq3_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
|
||||
#else
|
||||
|
||||
uint32_t aux32;
|
||||
|
||||
float sumf = 0.f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint8_t * restrict qs = x[i].qs;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const uint8_t * restrict signs = x[i].signs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
int32_t bsum = 0;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
||||
memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
|
||||
const uint32_t ls = 2*(aux32 >> 28) + 1;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
|
||||
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
|
||||
int32_t sumi = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32] << (8-2*l)) & 256)));
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32] << (7-2*l)) & 256)));
|
||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
|
||||
sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
|
||||
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
||||
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
||||
}
|
||||
q8 += 8;
|
||||
}
|
||||
qs += 8;
|
||||
bsum += sumi * ls;
|
||||
signs += 4;
|
||||
bsum += sumi * ls1;
|
||||
sumi = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
|
||||
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
|
||||
}
|
||||
q8 += 8;
|
||||
}
|
||||
qs += 8;
|
||||
signs += 4;
|
||||
bsum += sumi * ls2;
|
||||
}
|
||||
sumf += d * bsum;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue