ggml : fix q4_1

This commit is contained in:
Georgi Gerganov 2024-07-18 10:00:03 +03:00
parent 90e8f81556
commit e5e7a24ee7
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -4409,6 +4409,10 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
return;
}
#endif
int ib = 0;
float sumf = 0;
// TODO: add WASM SIMD
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -4416,13 +4420,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
float summs = 0;
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
for (; ib + 1 < nb; ib += 2) {
const block_q4_1 * restrict x0 = &x[ib + 0];
const block_q4_1 * restrict x1 = &x[ib + 1];
const block_q8_1 * restrict y0 = &y[ib + 0];
const block_q8_1 * restrict y1 = &y[ib + 1];
summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
@ -4451,7 +4453,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -4459,11 +4461,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
float summs = 0;
// Main loop
for (int i = 0; i < nb; ++i) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
const float d1 = GGML_FP16_TO_FP32(y[i].d);
for (; ib < nb; ++ib) {
const float d0 = GGML_FP16_TO_FP32(x[ib].d);
const float d1 = GGML_FP16_TO_FP32(y[ib].d);
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
const __m256 d0v = _mm256_set1_ps( d0 );
const __m256 d1v = _mm256_set1_ps( d1 );
@ -4472,8 +4474,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i qx = bytes_from_nibbles_32(x[i].qs);
const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs );
const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
@ -4485,18 +4487,18 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
#endif
}
*s = hsum_float_8(acc) + summs;
sumf = hsum_float_8(acc) + summs;
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
size_t vl = __riscv_vsetvl_e8m1(qk/2);
for (int i = 0; i < nb; i++) {
for (; ib < nb; ++ib) {
// load elements
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
// mask and store lower part of x, and then upper part
vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
@ -4515,11 +4517,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
}
*s = sumf;
#elif defined(__POWER9_VECTOR__)
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector signed int v0 = vec_splats((int32_t)0);
@ -4528,21 +4528,21 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
vector float vsumf0 = vec_splats(0.0f);
#pragma GCC unroll 4
for (int i = 0; i < nb; i++) {
__builtin_prefetch(x[i].qs, 0, 1);
__builtin_prefetch(y[i].qs, 0, 1);
for (; ib < nb; ++ib) {
__builtin_prefetch(x[ib].qs, 0, 1);
__builtin_prefetch(y[ib].qs, 0, 1);
vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
vector float vd = vec_mul(vxd, vyd);
vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m));
vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.0f, 0.0f, 0.0f};
vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
vsumf0 = vec_madd(vxmin, vys, vsumf0);
vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
vector signed char q8y0 = vec_xl( 0, y[i].qs);
vector signed char q8y1 = vec_xl(16, y[i].qs);
vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
vector signed char q8y0 = vec_xl( 0, y[ib].qs);
vector signed char q8y1 = vec_xl(16, y[ib].qs);
vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
@ -4558,7 +4558,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
*s = vec_extract(vsumf0, 0);
sumf = vec_extract(vsumf0, 0);
#elif defined(__loongarch_asx)
// Initialize accumulator with zeros
@ -4567,11 +4567,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
float summs = 0;
// Main loop
for (int i = 0; i < nb; ++i) {
const float d0 = GGML_FP16_TO_FP32(x[i].d);
const float d1 = GGML_FP16_TO_FP32(y[i].d);
for (; ib < nb; ++ib) {
const float d0 = GGML_FP16_TO_FP32(x[ib].d);
const float d1 = GGML_FP16_TO_FP32(y[ib].d);
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
@ -4580,8 +4580,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i qx = bytes_from_nibbles_32(x[i].qs);
const __m256i qy = __lasx_xvld( (const __m256i *)y[i].qs, 0);
const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
@ -4589,27 +4589,22 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
acc = __lasx_xvfmadd_s( d0d1, xy, acc );
}
*s = hsum_float_8(acc) + summs;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
sumf = hsum_float_8(acc) + summs;
#endif
for (; ib < nb; ++ib) {
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F);
const int v1 = (x[i].qs[j] >> 4);
const int v0 = (x[ib].qs[j] & 0x0F);
const int v1 = (x[ib].qs[j] >> 4);
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
}
*s = sumf;
#endif
}
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {