ggml : simplify Q8_1 - no need for low / high sums anymore

This commit is contained in:
Georgi Gerganov 2023-05-11 20:11:37 +03:00
parent 695f3963b1
commit 582a39fff5
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

69
ggml.c
View file

@ -718,12 +718,11 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block siz
#define QK8_1 32
typedef struct {
float d; // delta
float s0; // d * sum(qs[i]) low
float s1; // d * sum(qs[i]) high
int8_t qs[QK8_1]; // quants
float d; // delta
float s; // d * sum(qs[i])
int8_t qs[QK8_1]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@ -1078,8 +1077,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
y[i].d = d;
int sum0 = 0;
int sum1 = 0;
int sum = 0;
for (int j = 0; j < QK8_1/2; ++j) {
const float v0 = x[i*QK8_1 + 2*j + 0]*id;
@ -1088,12 +1086,11 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
y[i].qs[ j] = v0 + 0.5f;
y[i].qs[QK8_1/2 + j] = v1 + 0.5f;
sum0 += y[i].qs[ j];
sum1 += y[i].qs[QK8_1/2 + j];
sum += y[i].qs[ j];
sum += y[i].qs[QK8_1/2 + j];
}
y[i].s0 = d * sum0;
y[i].s1 = d * sum1;
y[i].s = d * sum;
}
}
@ -1123,11 +1120,9 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
y[i].d = d;
int32x4_t accv0 = vdupq_n_s32(0);
int32x4_t accv1 = vdupq_n_s32(0);
int32x4_t accv = vdupq_n_s32(0);
// low half
for (int j = 0; j < 4; j++) {
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
@ -1136,27 +1131,10 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
accv0 = vaddq_s32(accv0, vi);
accv = vaddq_s32(accv, vi);
}
// high half
for (int j = 4; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
y[i].qs[ 2*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[16 + 2*j + 0] = vgetq_lane_s32(vi, 1);
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
accv1 = vaddq_s32(accv1, vi);
}
const int32_t sum0 = vaddvq_s32(accv0);
const int32_t sum1 = vaddvq_s32(accv1);
y[i].s0 = d * sum0;
y[i].s1 = d * sum1;
y[i].s = d * vaddvq_s32(accv);
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
@ -1205,9 +1183,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@ -1237,8 +1213,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
// Compute the sum of the quants and set y[i].s
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
y[i].s0 = d * hsum_i32_4(s0);
y[i].s1 = d * hsum_i32_4(s1);
y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
@ -2200,7 +2175,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
summs += x0->m * y0->s + x1->m * y1->s;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
@ -2259,7 +2234,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
summs += x[i].m * (y[i].s0 + y[i].s1);
summs += x[i].m * y[i].s;
const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
@ -2292,7 +2267,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
}
*s = sumf;
@ -2545,8 +2520,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const uint8x16_t m4b = vdupq_n_u8(0x0F);
summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
// extract the 5th bit via lookup table ((b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
@ -2632,7 +2607,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const block_q5_1 * restrict x0 = &x[i];
const block_q8_1 * restrict y0 = &y[i];
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
const v128_t m4b = wasm_i8x16_splat(0x0F);
@ -2696,7 +2671,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
@ -2732,7 +2707,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;