ggml : use q4_0_q8_0 and q4_2_q8_0
This commit is contained in:
parent
d8bf7207f1
commit
6496b79e8e
1 changed files with 24 additions and 21 deletions
45
ggml.c
45
ggml.c
|
@ -1825,9 +1825,9 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
|
||||
|
@ -1837,7 +1837,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
|||
.quantize_row_q = quantize_row_q4_0,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_0_q8_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.dequantize_row_q = dequantize_row_q4_1,
|
||||
|
@ -1851,7 +1851,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
|||
.quantize_row_q = quantize_row_q4_2,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_2_q8_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_3] = {
|
||||
.dequantize_row_q = dequantize_row_q4_3,
|
||||
|
@ -2475,7 +2475,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|||
*s = sumf;
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const int nb = n / QK8_1;
|
||||
|
||||
assert(n % QK8_1 == 0);
|
||||
|
@ -2488,17 +2488,14 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
|
|||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
float sum8 = 0;
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_1 * restrict y0 = &y[i + 0];
|
||||
const block_q8_1 * restrict y1 = &y[i + 1];
|
||||
|
||||
sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
const int8x16_t s8b = vdupq_n_s8(0x8);
|
||||
|
||||
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
||||
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
||||
|
@ -2509,6 +2506,12 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
|
|||
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
||||
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
||||
|
||||
// sub 8
|
||||
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// load y
|
||||
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
||||
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
||||
|
@ -2523,21 +2526,21 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
|
|||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
||||
|
||||
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
@ -2549,7 +2552,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
|
|||
#endif
|
||||
}
|
||||
|
||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
|
||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
#elif defined(__AVX2__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
@ -2775,7 +2778,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const int nb = n / QK8_1;
|
||||
|
||||
assert(n % QK8_1 == 0);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue