Adding scalar and AVX2 Q2_K dot
commit 8516fdf728
parent b439efb712

2 changed files with 117 additions and 2 deletions
ggml.c (2 changes)
@@ -1571,7 +1571,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q           = quantize_row_q2_K,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
         .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = NULL, //ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q3_K] = {
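For context, and not part of the diff: ggml dispatches quantized matrix multiplications through this table, so the entry above wires Q2_K up end to end. Activations are quantized on the fly with quantize_row_q8_K (the .quantize_row_q_dot slot), and the new ggml_vec_dot_q2_K_q8_K consumes the result. The kernel below assumes the following block layouts; this is a sketch of the k_quants.h declarations (ggml_fp16_t comes from ggml.h, and QK_K is 256 here), so the header remains the authoritative source.

// Sketch of the block layouts the new kernel assumes (see k_quants.h).
typedef struct {
    uint8_t scales[QK_K/16]; // per group of 16: low nibble = scale, high nibble = min
    uint8_t qs[QK_K/4];      // 2-bit quants, four values packed per byte
    ggml_fp16_t d;           // super-block scale applied to the 4-bit scales
    ggml_fp16_t dmin;        // super-block scale applied to the 4-bit mins
} block_q2_K;

typedef struct {
    float   d;               // delta
    int8_t  qs[QK_K];        // 8-bit quants
    int16_t bsums[QK_K/16];  // precomputed sums of the quants in groups of 16
} block_q8_K;

A group-j weight dequantizes as d*scale[j]*q - dmin*min[j], so dotting it against 16 q8 values splits into d*scale[j]*(q . q8) minus dmin*min[j]*bsums[j]. That split is why quantize_row_q8_K precomputes bsums, and why both code paths below handle the mins separately from the scales.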
k_quants.c (117 changes)
@@ -1002,6 +1002,122 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = _mm256_add_epi32(p0, p1);
+            p2 = _mm256_add_epi32(p2, p3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < 16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        int isum = 0;
+        int is = 0;
+        int d;
+        for (int k = 0; k < QK_K/128; ++k) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                d = sc[is++] & 0xF;
+                int isuml = 0;
+                for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                d = sc[is++] & 0xF;
+                isuml = 0;
+                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                shift += 2;
+                q8 += 32;
+            }
+            q2 += 32;
+        }
+        sumf += dall * isum - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+
 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
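Two details of the AVX2 path above are easy to miss. First, dmin is negated when it is computed so the mins correction can ride the same _mm256_fmadd_ps shape as the positive term. Second, the 2-bit quants are unpacked with _mm256_srli_epi16 even though they are packed per byte; bits shifted across a byte boundary inside each 16-bit lane are garbage, but the following AND with m3 discards them, so the cheaper 16-bit shift suffices. The scale broadcast reuses get_scale_shuffle_q3k from the Q3_K kernel, which works because both types carry 16 scale groups per super-block. The final *s = hsum_float_8(acc) collapses the 8-lane accumulator to a scalar; hsum_float_8 is defined earlier in this file, and a typical shape for it is the following sketch (illustrative, not necessarily the exact code in k_quants.c):

#include <immintrin.h>

// Horizontal sum of the 8 floats in a __m256 accumulator.
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);         // high 4 lanes
    res = _mm_add_ps(res, _mm256_castps256_ps128(x)); // + low 4 lanes
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));   // 4 -> 2
    res = _mm_add_ss(res, _mm_movehdup_ps(res));      // 2 -> 1
    return _mm_cvtss_f32(res);
}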
@@ -1169,7 +1285,6 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vy) {
         int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
             // load low 2 bits
             const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
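A quick way to exercise the new kernel (not part of the commit; a minimal sketch assuming the declarations in k_quants.h, the same ones ggml.c relies on for the table entry above): quantize one random row to Q2_K and Q8_K, run the dot product, and compare it with the plain float dot. Q2_K is 2-bit, so only rough agreement is expected; building k_quants.c once with -mavx2 -mfma and once without also checks the AVX2 path against the scalar fallback.

#include <stdio.h>
#include <stdlib.h>
#include "k_quants.h"

int main(void) {
    float xf[QK_K], yf[QK_K];
    for (int i = 0; i < QK_K; ++i) {
        xf[i] = (float)rand()/RAND_MAX - 0.5f;
        yf[i] = (float)rand()/RAND_MAX - 0.5f;
    }

    block_q2_K xq;                      // one super-block of Q2_K weights
    block_q8_K yq;                      // one super-block of Q8_K activations
    quantize_row_q2_K(xf, &xq, QK_K);
    quantize_row_q8_K(yf, &yq, QK_K);

    float s;
    ggml_vec_dot_q2_K_q8_K(QK_K, &s, &xq, &yq);

    float ref = 0;
    for (int i = 0; i < QK_K; ++i) ref += xf[i]*yf[i];
    printf("quantized: %g  float: %g\n", s, ref);
    return 0;
}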