iq1s_blocks16: scalar and AVX2 dot products

Iwan Kawrakow 2024-03-08 15:50:36 +02:00
parent c55e66f997
commit 864a5c2ce4


@@ -3449,39 +3449,22 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-    float db[4];
-    uint16_t idx[4];
-    //const int8_t * grid[4];
-
     for (int i = 0; i < nb; i++) {
 
         const float d = GGML_FP16_TO_FP32(x[i].d);
-        const uint8_t * sc = (const uint8_t *)x[i].qh; //x[i].scales;
-        const uint8_t * qs = x[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
-        for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
-            //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
-            //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
-            //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
-            db[0] = d * (2*(sc[0] & 7) + 1);
-            db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = d * (2*(sc[1] & 7) + 1);
-            db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl = d * (2*(qh[ib] >> 12) + 1);
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
                 for (int j = 0; j < 8; ++j) {
-                    //y[j] = db[l] * grid[l][j];
-                    y[j] = db[l] * grid[j];
+                    y[j] = dl * grid[j];
                 }
                 y += 8;
             }
             qs += 4;
-            sc += 2;
         }
     }
 }
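
In the layout this diff switches to, each uint16_t qh[ib] covers 32 weights: bits 0-11 pack four 3-bit high index bits (one 3-bit field per group of 8 weights) and bits 12-15 hold a 4-bit block scale. The low 8 index bits come from qs, so every group of 8 weights selects one of 2048 rows of iq1s_grid, and the effective scale is the odd multiple d * (2*(qh[ib] >> 12) + 1). A minimal standalone sketch of that unpacking for a single 32-weight block; the helper name and signature are illustrative, not part of the commit:

#include <stdint.h>

// Sketch only: decodes one 32-weight block of the new iq1_s layout.
// Assumes iq1s_grid is the same table used above: 2048 uint64 entries,
// each packing 8 int8 grid values.
static void iq1s_unpack_block32(float d, const uint8_t qs[4], uint16_t qh,
                                const uint64_t * iq1s_grid, float y[32]) {
    // bits 12..15 of qh: 4-bit block scale s, applied as the odd factor 2*s + 1
    const float dl = d * (2*(qh >> 12) + 1);
    for (int l = 0; l < 4; ++l) {
        // 11-bit grid index: qs[l] gives bits 0..7, qh bits 3l..3l+2 give bits 8..10
        const int idx = qs[l] | (((qh >> 3*l) & 7) << 8);
        const int8_t * grid = (const int8_t *)(iq1s_grid + idx);
        for (int j = 0; j < 8; ++j) {
            y[8*l + j] = dl * grid[j];
        }
    }
}
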
@@ -9645,55 +9628,31 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
     *s = sumf;
 
-// TODO: implement for QK_K = 64
-#elif defined __AVX2__ && QK_K == 256
-
-    const __m128i m8 = _mm_set1_epi8(0x08);
-    const __m128i m7 = _mm_set1_epi8(0x07);
-    const __m128i m1 = _mm_set1_epi8(0x01);
-    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
-    const __m128i shuffle_s[4] = {
-        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
-        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
-        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
-        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
-    };
-
-    uint64_t aux64;
-
-    typedef union m256i_uint16 {
-        __m256i reg;
-        uint16_t s[16];
-    } m256i_uint16_t;
-
-    m256i_uint16_t v_gindex;
+#elif defined __AVX2__
 
     __m256 accum = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = (const uint8_t *)x[i].qh; //x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         __m256i sumi = _mm256_setzero_si256();
-        for (int i128 = 0; i128 < QK_K/128; ++i128) {
-            const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
-            memcpy(&aux64, sc, 8); sc += 8;
-            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
-            const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
-            v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
-            const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
-            for (int i32 = 0; i32 < 4; ++i32) {
-                const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
-                                                      iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
-                const __m256i dot = mul_add_epi8(q1b, q8b);
-                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
-                const __m256i p = _mm256_madd_epi16(s16, dot);
-                sumi = _mm256_add_epi32(sumi, p);
-            }
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | (((qh[ib+0] >> 9) & 7) << 8)], iq1s_grid[qs[2] | (((qh[ib+0] >> 6) & 7) << 8)],
+                                                    iq1s_grid[qs[1] | (((qh[ib+0] >> 3) & 7) << 8)], iq1s_grid[qs[0] | (((qh[ib+0] >> 0) & 7) << 8)]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | (((qh[ib+1] >> 9) & 7) << 8)], iq1s_grid[qs[6] | (((qh[ib+1] >> 6) & 7) << 8)],
+                                                    iq1s_grid[qs[5] | (((qh[ib+1] >> 3) & 7) << 8)], iq1s_grid[qs[4] | (((qh[ib+1] >> 0) & 7) << 8)]);
+            qs += 8;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*(qh[ib+0] >> 12) + 1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*(qh[ib+1] >> 12) + 1));
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
         }
 
         accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
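
The rewritten AVX2 inner loop handles two 32-weight blocks per iteration: the eight 8-byte grid rows are gathered straight from iq1s_grid with _mm256_set_epi64x (no more shuffle/union staging), multiplied against the q8_K quants with mul_add_epi8, and each block's int16 partial sums are widened into sumi by _mm256_madd_epi16 against the broadcast odd scale 2*(qh[ib] >> 12) + 1. For context, mul_add_epi8 is the small helper defined earlier in ggml-quants.c, essentially:

#include <immintrin.h>

// Pairwise int8 multiply with horizontal add to int16.
// _mm256_maddubs_epi16 requires an unsigned first operand, so the sign
// of x is first folded into y: |x[i]| * (sign(x[i]) * y[i]) == x[i] * y[i].
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x); // |x|
    const __m256i sy = _mm256_sign_epi8(y, x); // y with x's signs applied
    return _mm256_maddubs_epi16(ax, sy);
}

Since the grid values are only -1, 0, or 1, the maddubs step cannot saturate (each int16 lane is at most 2 * 127 in magnitude), and that dot times a scale of at most 31 fits comfortably in the int32 lanes of sumi.
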
@@ -9704,35 +9663,26 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
 #else
 
-    int db[4];
-    uint16_t idx[4];
-
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; i++) {
 
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * sc = x[i].scales;
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
 
         int sumi = 0;
-        for (int i32 = 0; i32 < QK_K/32; ++i32) {
-            idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
-            idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
-            idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
-            idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
-            db[0] = (2*(sc[0] & 7) + 1);
-            db[1] = (2*((sc[0] >> 4) & 7) + 1);
-            db[2] = (2*(sc[1] & 7) + 1);
-            db[3] = (2*((sc[1] >> 4) & 7) + 1);
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = 2*(qh[ib] >> 12) + 1;
+            int lsum = 0;
             for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
-                int suml = 0;
-                for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
-                sumi += db[l] * suml;
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    lsum += q8[j] * grid[j];
+                }
                 q8 += 8;
             }
+            sumi += ls * lsum;
             qs += 4;
-            sc += 2;
         }
 
         sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
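
All three paths above read the same super-block layout. For orientation, a struct consistent with these access patterns (assumed from the code in this diff; the actual definition lives in the quantization headers, not in these hunks):

#include <stdint.h>

#define QK_K 256
typedef uint16_t ggml_fp16_t;  // FP16 bit pattern, as in ggml

// Assumed iq1_s super-block, matching the accesses above.
typedef struct {
    ggml_fp16_t d;             // FP16 super-block scale
    uint8_t     qs[QK_K/8];    // low 8 bits of each 11-bit grid index
    uint16_t    qh[QK_K/32];   // per 32 weights: 4 x 3 high index bits + 4-bit scale
} block_iq1_s;

At QK_K = 256 that is 2 + 32 + 16 = 50 bytes per 256 weights, i.e. 50*8/256 = 1.5625 bits per weight.
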