ggml : fix bug - using wrong block type

This commit is contained in:
Georgi Gerganov 2023-04-25 22:28:26 +03:00
parent 6e0f0b6ff1
commit 46fc696dea
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

16
ggml.c
View file

@ -1836,7 +1836,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.dequantize_row_q = dequantize_row_q4_0,
.quantize_row_q = quantize_row_q4_0,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
.quantize_row_q_dot = quantize_row_q8_1,
.quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
},
[GGML_TYPE_Q4_1] = {
@ -1850,7 +1850,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.dequantize_row_q = dequantize_row_q4_2,
.quantize_row_q = quantize_row_q4_2,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
.quantize_row_q_dot = quantize_row_q8_1,
.quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
},
[GGML_TYPE_Q4_3] = {
@ -2482,7 +2482,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
assert(nb % 2 == 0);
const block_q4_0 * restrict x = vx;
const block_q8_1 * restrict y = vy;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -2491,8 +2491,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
@ -2786,7 +2786,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
assert(QK8_1 == 2*QK4_2);
const block_q4_2 * restrict x = vx;
const block_q8_1 * restrict y = vy;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -2798,8 +2798,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);