ggml : simplify Q8_1 - no need for low / high sums anymore

Georgi Gerganov 2023-05-11 20:11:37 +03:00
parent 695f3963b1
commit 582a39fff5
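
For context on why the low/high split can go away: every consumer of `block_q8_1` only ever uses the two partial sums as their total, e.g. `x[i].m * (y[i].s0 + y[i].s1)` in the Q4_1 and Q5_1 dot products, so a single precomputed `s = d * sum(qs[i])` carries the same information and shaves 4 bytes off the block. Below is a minimal standalone sketch of the simplified block and of a scalar Q4_1 x Q8_1 block dot product that consumes it; the `block_q4_1` layout, the `QK*` defines, and the `dot_q4_1_q8_1_block` helper are illustrative assumptions for this note, not code copied from ggml.c.

    #include <stdint.h>

    #define QK8_1 32
    #define QK4_1 32

    // simplified Q8_1 block: one precomputed sum instead of low/high halves
    typedef struct {
        float  d;          // delta
        float  s;          // d * sum(qs[i])
        int8_t qs[QK8_1];  // quants
    } block_q8_1;

    // Q4_1 block layout assumed for illustration: delta + min, packed nibbles
    typedef struct {
        float   d;            // delta
        float   m;            // min
        uint8_t qs[QK4_1/2];  // nibbles
    } block_q4_1;

    // illustrative scalar dot product of one Q4_1 block with one Q8_1 block:
    //   dot = sum_j (d_x*q_x[j] + m_x) * (d_y*q_y[j])
    //       = d_x*d_y * sum_j q_x[j]*q_y[j]  +  m_x * (d_y * sum_j q_y[j])
    //       = d_x*d_y * sumi                 +  m_x * y->s
    static float dot_q4_1_q8_1_block(const block_q4_1 *x, const block_q8_1 *y) {
        int sumi = 0;
        for (int j = 0; j < QK4_1/2; ++j) {
            const int v0 = x->qs[j] & 0x0F;  // low nibble
            const int v1 = x->qs[j] >> 4;    // high nibble
            sumi += v0 * y->qs[j] + v1 * y->qs[j + QK8_1/2];
        }
        // the offset term only ever needs the total sum of the Q8_1 quants,
        // which is exactly what y->s = d * sum(qs[i]) precomputes
        return x->d * y->d * sumi + x->m * y->s;
    }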

ggml.c

@@ -718,12 +718,11 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 #define QK8_1 32
 typedef struct {
     float   d;         // delta
-    float   s0;        // d * sum(qs[i]) low
-    float   s1;        // d * sum(qs[i]) high
+    float   s;         // d * sum(qs[i])
     int8_t  qs[QK8_1]; // quants
 } block_q8_1;
-static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
+static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@@ -1078,8 +1077,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
         y[i].d = d;
 
-        int sum0 = 0;
-        int sum1 = 0;
+        int sum = 0;
 
         for (int j = 0; j < QK8_1/2; ++j) {
             const float v0 = x[i*QK8_1 + 2*j + 0]*id;
@@ -1088,12 +1086,11 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
             y[i].qs[          j] = v0 + 0.5f;
             y[i].qs[QK8_1/2 + j] = v1 + 0.5f;
 
-            sum0 += y[i].qs[          j];
-            sum1 += y[i].qs[QK8_1/2 + j];
+            sum += y[i].qs[          j];
+            sum += y[i].qs[QK8_1/2 + j];
         }
 
-        y[i].s0 = d * sum0;
-        y[i].s1 = d * sum1;
+        y[i].s = d * sum;
     }
 }
@@ -1123,11 +1120,9 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
         y[i].d = d;
 
-        int32x4_t accv0 = vdupq_n_s32(0);
-        int32x4_t accv1 = vdupq_n_s32(0);
+        int32x4_t accv = vdupq_n_s32(0);
 
-        // low half
-        for (int j = 0; j < 4; j++) {
+        for (int j = 0; j < 8; j++) {
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1136,27 +1131,10 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
             y[i].qs[     2*j + 1] = vgetq_lane_s32(vi, 2);
             y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
 
-            accv0 = vaddq_s32(accv0, vi);
+            accv = vaddq_s32(accv, vi);
         }
 
-        // high half
-        for (int j = 4; j < 8; j++) {
-            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
-            const int32x4_t   vi = vcvtnq_s32_f32(v);
-
-            y[i].qs[     2*j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[16 + 2*j + 0] = vgetq_lane_s32(vi, 1);
-            y[i].qs[     2*j + 1] = vgetq_lane_s32(vi, 2);
-            y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
-
-            accv1 = vaddq_s32(accv1, vi);
-        }
-
-        const int32_t sum0 = vaddvq_s32(accv0);
-        const int32_t sum1 = vaddvq_s32(accv1);
-
-        y[i].s0 = d * sum0;
-        y[i].s1 = d * sum1;
+        y[i].s = d * vaddvq_s32(accv);
     }
 
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1205,9 +1183,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
 #if defined(__AVX2__)
         // Compute the sum of the quants and set y[i].s
-        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
-        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
-        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
@@ -1237,8 +1213,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
         // Compute the sum of the quants and set y[i].s
         const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
         const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
-        y[i].s0 = d * hsum_i32_4(s0);
-        y[i].s1 = d * hsum_i32_4(s1);
+        y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
 
         // Convert int32 to int16
         ni0 = _mm_packs_epi32( ni0, ni1 );
@@ -2200,7 +2175,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const block_q8_1 * restrict y0 = &y[i + 0];
         const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+        summs += x0->m * y0->s + x1->m * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2259,7 +2234,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
 
-        summs += x[i].m * (y[i].s0 + y[i].s1);
+        summs += x[i].m * y[i].s;
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
@@ -2292,7 +2267,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
+        sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
     }
 
     *s = sumf;
@@ -2545,8 +2520,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
-        summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
-        summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
+        summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
+        summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
 
         // extract the 5th bit via lookup table ((b) << 4)
         memcpy(&qh0, x0->qh, sizeof(qh0));
@@ -2632,7 +2607,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const block_q5_1 * restrict x0 = &x[i];
         const block_q8_1 * restrict y0 = &y[i];
 
-        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
 
         const v128_t m4b = wasm_i8x16_splat(0x0F);
@@ -2696,7 +2671,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     for (int i = 0; i < nb; i++) {
         const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
 
-        summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2732,7 +2707,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
 
     *s = sumf;