iq1_m: AVX2 dot product
This commit is contained in:
parent
64b9dfd7ff
commit
a139de51b6
1 changed file with 50 additions and 23 deletions
|
@ -9748,7 +9748,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|||
#endif
|
||||
}
|
||||
|
||||
// TODO
|
||||
void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
|
@ -9763,6 +9762,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
|||
const int nb = n / QK_K;
|
||||
|
||||
#if defined z__ARM_NEON
|
||||
// TODO
|
||||
|
||||
ggml_int8x16x4_t q1b;
|
||||
ggml_int8x16x4_t q8b;
|
||||
|
@ -9807,46 +9807,73 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
|||
|
||||
*s = sumf;
|
||||
|
||||
#elif defined z__AVX2__
|
||||
#elif defined __AVX2__
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
float accum1 = 0;
|
||||
iq1m_scale_t scale;
|
||||
|
||||
__m256 accum1 = _mm256_setzero_ps();
|
||||
__m256 accum2 = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint16_t * qh = x[i].qh;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
int sumi1 = 0;
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
|
||||
__m256i sumi1 = _mm256_setzero_si256();
|
||||
__m256i sumi2 = _mm256_setzero_si256();
|
||||
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
|
||||
iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
|
||||
const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
|
||||
iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
|
||||
qs += 8;
|
||||
const __m256i q1b_1 = _mm256_set_epi64x(
|
||||
iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
|
||||
iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
|
||||
);
|
||||
const __m256i q1b_2 = _mm256_set_epi64x(
|
||||
iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
|
||||
iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
|
||||
);
|
||||
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
|
||||
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
|
||||
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
|
||||
const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
|
||||
const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
|
||||
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
|
||||
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
|
||||
|
||||
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
|
||||
sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
|
||||
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
|
||||
const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
|
||||
const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
|
||||
qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
|
||||
|
||||
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
|
||||
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
|
||||
const int16_t ls1 = 2*((sc[ib/2] >> 0) & 7) + 1;
|
||||
const int16_t ls2 = 2*((sc[ib/2] >> 3) & 7) + 1;
|
||||
const int16_t ls3 = 2*((sc[ib/2] >> 6) & 7) + 1;
|
||||
const int16_t ls4 = 2*((sc[ib/2] >> 9) & 7) + 1;
|
||||
const __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1));
|
||||
const __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3));
|
||||
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
|
||||
const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
|
||||
const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
|
||||
const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
|
||||
|
||||
sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
|
||||
sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
|
||||
|
||||
qs += 8; qh += 4;
|
||||
}
|
||||
|
||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||
accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
|
||||
accum1 += d * sumi1;
|
||||
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.fp16));
|
||||
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
|
||||
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
|
||||
*s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
|
||||
|
||||
#else
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue