diff --git a/ggml-quants.c b/ggml-quants.c index e4db8940d..f2e6c4bd1 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9857,7 +9857,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void #elif defined __AVX2__ +#if QK_K == 64 + const __m256i mask = _mm256_set1_epi16(0xf); +#else const __m256i mask = _mm256_set1_epi16(0x7); +#endif const __m256i mone = _mm256_set1_epi16(1); __m256 accum1 = _mm256_setzero_ps(); @@ -9869,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const uint8_t * qh = x[i].qh; const uint16_t * sc = (const uint16_t *)x[i].scales; +#if QK_K != 64 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); @@ -9899,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const __m256i dot3 = mul_add_epi8(delta1, q8b_1); const __m256i dot4 = mul_add_epi8(delta2, q8b_2); +#if QK_K == 64 + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8)); +#else __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); +#endif scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); @@ -9914,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void qs += 8; qh += 4; } +#if QK_K == 64 + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); +#else const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); +#endif accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); @@ -12026,11 +12041,10 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy #if QK_K == 64 y[ibl].d = GGML_FP32_TO_FP16(0.f); -#else +#endif memset(y[ibl].qs, 0, QK_K/8); memset(y[ibl].qh, 0, QK_K/16); memset(y[ibl].scales, 0, QK_K/32); -#endif float max_scale = 0;