iq1_m: AVX2 dot product

This commit is contained in:
Iwan Kawrakow 2024-03-23 16:38:13 +02:00
parent 64b9dfd7ff
commit a139de51b6

View file

@ -9748,7 +9748,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
#endif
}
// TODO
void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
@ -9763,6 +9762,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const int nb = n / QK_K;
#if defined z__ARM_NEON
// TODO
ggml_int8x16x4_t q1b;
ggml_int8x16x4_t q8b;
@ -9807,46 +9807,73 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
*s = sumf;
#elif defined z__AVX2__
#elif defined __AVX2__
__m256 accum = _mm256_setzero_ps();
float accum1 = 0;
iq1m_scale_t scale;
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint16_t * qh = x[i].qh;
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
__m256i sumi = _mm256_setzero_si256();
int sumi1 = 0;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib = 0; ib < QK_K/32; ib += 2) {
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
qs += 8;
const __m256i q1b_1 = _mm256_set_epi64x(
iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
);
const __m256i q1b_2 = _mm256_set_epi64x(
iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
);
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
const int16_t ls1 = 2*((sc[ib/2] >> 0) & 7) + 1;
const int16_t ls2 = 2*((sc[ib/2] >> 3) & 7) + 1;
const int16_t ls3 = 2*((sc[ib/2] >> 6) & 7) + 1;
const int16_t ls4 = 2*((sc[ib/2] >> 9) & 7) + 1;
const __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1));
const __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3));
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
qs += 8; qh += 4;
}
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
accum1 += d * sumi1;
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.fp16));
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
}
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
*s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
#else