ggml : simplify Q8_1 - no need for low / high sums anymore
This commit is contained in:
parent
695f3963b1
commit
582a39fff5
1 changed files with 22 additions and 47 deletions
65
ggml.c
65
ggml.c
|
@ -719,11 +719,10 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block siz
|
|||
#define QK8_1 32
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
float s0; // d * sum(qs[i]) low
|
||||
float s1; // d * sum(qs[i]) high
|
||||
float s; // d * sum(qs[i])
|
||||
int8_t qs[QK8_1]; // quants
|
||||
} block_q8_1;
|
||||
static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
|
||||
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
||||
|
@ -1078,8 +1077,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
|
|||
|
||||
y[i].d = d;
|
||||
|
||||
int sum0 = 0;
|
||||
int sum1 = 0;
|
||||
int sum = 0;
|
||||
|
||||
for (int j = 0; j < QK8_1/2; ++j) {
|
||||
const float v0 = x[i*QK8_1 + 2*j + 0]*id;
|
||||
|
@ -1088,12 +1086,11 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
|
|||
y[i].qs[ j] = v0 + 0.5f;
|
||||
y[i].qs[QK8_1/2 + j] = v1 + 0.5f;
|
||||
|
||||
sum0 += y[i].qs[ j];
|
||||
sum1 += y[i].qs[QK8_1/2 + j];
|
||||
sum += y[i].qs[ j];
|
||||
sum += y[i].qs[QK8_1/2 + j];
|
||||
}
|
||||
|
||||
y[i].s0 = d * sum0;
|
||||
y[i].s1 = d * sum1;
|
||||
y[i].s = d * sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1123,11 +1120,9 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
|||
|
||||
y[i].d = d;
|
||||
|
||||
int32x4_t accv0 = vdupq_n_s32(0);
|
||||
int32x4_t accv1 = vdupq_n_s32(0);
|
||||
int32x4_t accv = vdupq_n_s32(0);
|
||||
|
||||
// low half
|
||||
for (int j = 0; j < 4; j++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[j], id);
|
||||
const int32x4_t vi = vcvtnq_s32_f32(v);
|
||||
|
||||
|
@ -1136,27 +1131,10 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
|||
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
|
||||
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
|
||||
|
||||
accv0 = vaddq_s32(accv0, vi);
|
||||
accv = vaddq_s32(accv, vi);
|
||||
}
|
||||
|
||||
// high half
|
||||
for (int j = 4; j < 8; j++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[j], id);
|
||||
const int32x4_t vi = vcvtnq_s32_f32(v);
|
||||
|
||||
y[i].qs[ 2*j + 0] = vgetq_lane_s32(vi, 0);
|
||||
y[i].qs[16 + 2*j + 0] = vgetq_lane_s32(vi, 1);
|
||||
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
|
||||
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
|
||||
|
||||
accv1 = vaddq_s32(accv1, vi);
|
||||
}
|
||||
|
||||
const int32_t sum0 = vaddvq_s32(accv0);
|
||||
const int32_t sum1 = vaddvq_s32(accv1);
|
||||
|
||||
y[i].s0 = d * sum0;
|
||||
y[i].s1 = d * sum1;
|
||||
y[i].s = d * vaddvq_s32(accv);
|
||||
}
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
@ -1205,9 +1183,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
|||
|
||||
#if defined(__AVX2__)
|
||||
// Compute the sum of the quants and set y[i].s
|
||||
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
||||
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
|
||||
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
|
||||
y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||
|
@ -1237,8 +1213,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
|||
// Compute the sum of the quants and set y[i].s
|
||||
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
|
||||
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
|
||||
y[i].s0 = d * hsum_i32_4(s0);
|
||||
y[i].s1 = d * hsum_i32_4(s1);
|
||||
y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
|
||||
|
||||
// Convert int32 to int16
|
||||
ni0 = _mm_packs_epi32( ni0, ni1 );
|
||||
|
@ -2200,7 +2175,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
const block_q8_1 * restrict y0 = &y[i + 0];
|
||||
const block_q8_1 * restrict y1 = &y[i + 1];
|
||||
|
||||
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
|
||||
summs += x0->m * y0->s + x1->m * y1->s;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
||||
|
||||
|
@ -2259,7 +2234,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
const float * d0 = &x[i].d;
|
||||
const float * d1 = &y[i].d;
|
||||
|
||||
summs += x[i].m * (y[i].s0 + y[i].s1);
|
||||
summs += x[i].m * y[i].s;
|
||||
|
||||
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
||||
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
||||
|
@ -2292,7 +2267,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
|
||||
}
|
||||
|
||||
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
|
||||
sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
|
@ -2545,8 +2520,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
||||
|
||||
summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
||||
summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
|
||||
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
|
||||
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
|
||||
|
||||
// extract the 5th bit via lookup table ((b) << 4)
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
|
@ -2632,7 +2607,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
const block_q5_1 * restrict x0 = &x[i];
|
||||
const block_q8_1 * restrict y0 = &y[i];
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
|
||||
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
|
||||
|
||||
const v128_t m4b = wasm_i8x16_splat(0x0F);
|
||||
|
||||
|
@ -2696,7 +2671,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
for (int i = 0; i < nb; i++) {
|
||||
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
|
@ -2732,7 +2707,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
|
||||
}
|
||||
|
||||
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
|
||||
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue