iq1_m: make it work for QK_K = 64 (scalar and AVX2)
This commit is contained in:
parent
e1939bc869
commit
5c953a1a15
1 changed files with 16 additions and 2 deletions
|
@ -9857,7 +9857,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
||||||
|
|
||||||
#elif defined __AVX2__
|
#elif defined __AVX2__
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
const __m256i mask = _mm256_set1_epi16(0xf);
|
||||||
|
#else
|
||||||
const __m256i mask = _mm256_set1_epi16(0x7);
|
const __m256i mask = _mm256_set1_epi16(0x7);
|
||||||
|
#endif
|
||||||
const __m256i mone = _mm256_set1_epi16(1);
|
const __m256i mone = _mm256_set1_epi16(1);
|
||||||
|
|
||||||
__m256 accum1 = _mm256_setzero_ps();
|
__m256 accum1 = _mm256_setzero_ps();
|
||||||
|
@ -9869,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
||||||
const uint8_t * qh = x[i].qh;
|
const uint8_t * qh = x[i].qh;
|
||||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||||
|
|
||||||
|
#if QK_K != 64
|
||||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||||
|
#endif
|
||||||
|
|
||||||
__m256i sumi1 = _mm256_setzero_si256();
|
__m256i sumi1 = _mm256_setzero_si256();
|
||||||
__m256i sumi2 = _mm256_setzero_si256();
|
__m256i sumi2 = _mm256_setzero_si256();
|
||||||
|
@ -9899,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
||||||
|
|
||||||
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
|
const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
|
||||||
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
|
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
|
||||||
|
#if QK_K == 64
|
||||||
|
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
|
||||||
|
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
|
||||||
|
#else
|
||||||
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
|
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
|
||||||
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
|
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
|
||||||
|
#endif
|
||||||
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
|
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
|
||||||
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
|
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
|
||||||
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
|
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
|
||||||
|
@ -9914,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
||||||
qs += 8; qh += 4;
|
qs += 8; qh += 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
|
||||||
|
#else
|
||||||
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
|
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
|
||||||
|
#endif
|
||||||
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
|
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
|
||||||
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
|
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
|
||||||
|
|
||||||
|
@ -12026,11 +12041,10 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
|
||||||
|
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
||||||
#else
|
#endif
|
||||||
memset(y[ibl].qs, 0, QK_K/8);
|
memset(y[ibl].qs, 0, QK_K/8);
|
||||||
memset(y[ibl].qh, 0, QK_K/16);
|
memset(y[ibl].qh, 0, QK_K/16);
|
||||||
memset(y[ibl].scales, 0, QK_K/32);
|
memset(y[ibl].scales, 0, QK_K/32);
|
||||||
#endif
|
|
||||||
|
|
||||||
float max_scale = 0;
|
float max_scale = 0;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue