iq1_m: make it work for QK_K = 64 (scalar and AVX2)

This commit is contained in:
Iwan Kawrakow 2024-03-26 20:03:11 +02:00
parent e1939bc869
commit 5c953a1a15

View file

@ -9857,7 +9857,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
#elif defined __AVX2__ #elif defined __AVX2__
#if QK_K == 64
const __m256i mask = _mm256_set1_epi16(0xf);
#else
const __m256i mask = _mm256_set1_epi16(0x7); const __m256i mask = _mm256_set1_epi16(0x7);
#endif
const __m256i mone = _mm256_set1_epi16(1); const __m256i mone = _mm256_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps(); __m256 accum1 = _mm256_setzero_ps();
@ -9869,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const uint8_t * qh = x[i].qh; const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales; const uint16_t * sc = (const uint16_t *)x[i].scales;
#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
#endif
__m256i sumi1 = _mm256_setzero_si256(); __m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256();
@ -9899,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const __m256i dot3 = mul_add_epi8(delta1, q8b_1); const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
const __m256i dot4 = mul_add_epi8(delta2, q8b_2); const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
#if QK_K == 64
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
#else
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
#endif
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
@ -9914,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
qs += 8; qh += 4; qs += 8; qh += 4;
} }
#if QK_K == 64
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
#else
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
#endif
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
@ -12026,11 +12041,10 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
#if QK_K == 64 #if QK_K == 64
y[ibl].d = GGML_FP32_TO_FP16(0.f); y[ibl].d = GGML_FP32_TO_FP16(0.f);
#else #endif
memset(y[ibl].qs, 0, QK_K/8); memset(y[ibl].qs, 0, QK_K/8);
memset(y[ibl].qh, 0, QK_K/16); memset(y[ibl].qh, 0, QK_K/16);
memset(y[ibl].scales, 0, QK_K/32); memset(y[ibl].scales, 0, QK_K/32);
#endif
float max_scale = 0; float max_scale = 0;