ggml : fix q8_0

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-18 10:53:03 +03:00
parent 79b95e3420
commit 62a3185ca6
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@@ -5395,18 +5395,20 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
return; return;
} }
#endif #endif
int ib = 0;
float sumf = 0;
#if defined(__ARM_FEATURE_SVE) #if defined(__ARM_FEATURE_SVE)
if (svcntb() == QK8_0) { if (svcntb() == QK8_0) {
svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb for (; ib + 1 < nb; ib += 2) {
const block_q8_0 * restrict x0 = &x[ib + 0];
for (int i = 0; i < nb; i += 2) { const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict x0 = &x[i + 0]; const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y1 = &y[ib + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
// load x // load x
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
@@ -5420,21 +5422,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
} }
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
return;
} }
#endif #elif defined(__ARM_NEON)
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb for (; ib + 1 < nb; ib += 2) {
const block_q8_0 * restrict x0 = &x[ib + 0];
for (int i = 0; i < nb; i += 2) { const block_q8_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict x0 = &x[i + 0]; const block_q8_0 * restrict y0 = &y[ib + 0];
const block_q8_0 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y1 = &y[ib + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const int8x16_t x0_0 = vld1q_s8(x0->qs); const int8x16_t x0_0 = vld1q_s8(x0->qs);
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
@@ -5456,17 +5454,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
} }
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__) || defined(__AVX__) #elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
// Main loop // Main loop
for (int i = 0; i < nb; ++i) { for (; ib < nb; ++ib) {
// Compute combined scale for the block // Compute combined scale for the block
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
__m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
__m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
const __m256 q = mul_sum_i8_pairs_float(qx, qy); const __m256 q = mul_sum_i8_pairs_float(qx, qy);
@@ -5478,15 +5476,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
*s = hsum_float_8(acc); sumf = hsum_float_8(acc);
#elif defined(__riscv_v_intrinsic) #elif defined(__riscv_v_intrinsic)
float sumf = 0.0; float sumf = 0.0;
size_t vl = __riscv_vsetvl_e8m1(qk); size_t vl = __riscv_vsetvl_e8m1(qk);
for (int i = 0; i < nb; i++) { for (; ib < nb; ++ib) {
// load elements // load elements
vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
@@ -5495,28 +5493,25 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
} }
*s = sumf;
#elif defined(__POWER9_VECTOR__) #elif defined(__POWER9_VECTOR__)
const vector signed int v0 = vec_splats((int32_t)0); const vector signed int v0 = vec_splats((int32_t)0);
vector float vsumf0 = vec_splats(0.0f); vector float vsumf0 = vec_splats(0.0f);
#pragma GCC unroll 8 #pragma GCC unroll 8
for (int i = 0; i < nb; i++) { for (; ib < nb; ++ib) {
__builtin_prefetch(x[i].qs, 0, 1); __builtin_prefetch(x[ib].qs, 0, 1);
__builtin_prefetch(y[i].qs, 0, 1); __builtin_prefetch(y[ib].qs, 0, 1);
vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
vector float vd = vec_mul(vxd, vyd); vector float vd = vec_mul(vxd, vyd);
vector signed char q8x0 = vec_xl( 0, x[i].qs); vector signed char q8x0 = vec_xl( 0, x[ib].qs);
vector signed char q8x1 = vec_xl(16, x[i].qs); vector signed char q8x1 = vec_xl(16, x[ib].qs);
vector signed char q8y0 = vec_xl( 0, y[i].qs); vector signed char q8y0 = vec_xl( 0, y[ib].qs);
vector signed char q8y1 = vec_xl(16, y[i].qs); vector signed char q8y1 = vec_xl(16, y[ib].qs);
vector signed short qv0 = vec_mule(q8x0, q8y0); vector signed short qv0 = vec_mule(q8x0, q8y0);
vector signed short qv1 = vec_mulo(q8x0, q8y0); vector signed short qv1 = vec_mulo(q8x0, q8y0);
@@ -5539,18 +5534,18 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
*s = vec_extract(vsumf0, 0); sumf = vec_extract(vsumf0, 0);
#elif defined(__loongarch_asx) #elif defined(__loongarch_asx)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = (__m256)__lasx_xvldi(0); __m256 acc = (__m256)__lasx_xvldi(0);
// Main loop // Main loop
for (int i = 0; i < nb; ++i) { for (; ib < nb; ++ib) {
// Compute combined scale for the block // Compute combined scale for the block
const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
__m256i qx = __lasx_xvld((const __m256i *)x[i].qs, 0); __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
__m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
const __m256 q = mul_sum_i8_pairs_float(qx, qy); const __m256 q = mul_sum_i8_pairs_float(qx, qy);
@@ -5558,24 +5553,19 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
acc = __lasx_xvfmadd_s( d, q, acc ); acc = __lasx_xvfmadd_s( d, q, acc );
} }
*s = hsum_float_8(acc); sumf = hsum_float_8(acc);
#endif
#else for (; ib < nb; ++ib) {
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0; int sumi = 0;
for (int j = 0; j < qk; j++) { for (int j = 0; j < qk; j++) {
sumi += x[i].qs[j]*y[i].qs[j]; sumi += x[ib].qs[j]*y[ib].qs[j];
} }
sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
} }
*s = sumf; *s = sumf;
#endif
} }
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {