diff --git a/ggml.c b/ggml.c index 80cd31b92..83751dd0a 100644 --- a/ggml.c +++ b/ggml.c @@ -1838,6 +1838,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, }, [GGML_TYPE_Q4_1] = { .dequantize_row_q = dequantize_row_q4_1, @@ -1845,6 +1846,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, .quantize_row_q_dot = quantize_row_q8_1, .vec_dot_q = ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q4_2] = { .dequantize_row_q = dequantize_row_q4_2, @@ -1852,6 +1854,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference, .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, }, [GGML_TYPE_Q4_3] = { .dequantize_row_q = dequantize_row_q4_3, @@ -1859,6 +1862,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, .quantize_row_q_dot = quantize_row_q8_1, .vec_dot_q = ggml_vec_dot_q4_3_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q8_0] = { .dequantize_row_q = dequantize_row_q8_0, @@ -1866,6 +1870,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference, .quantize_row_q_dot = quantize_row_q8_0, .vec_dot_q = ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, }, [GGML_TYPE_Q8_1] = { .dequantize_row_q = NULL, // TODO @@ -1873,6 +1878,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference, .quantize_row_q_dot = quantize_row_q8_1, .vec_dot_q = NULL, // TODO + .vec_dot_type = GGML_TYPE_Q8_1, }, }; @@ -2476,9 +2482,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t } static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_1; + const int nb = n / QK8_0; - assert(n % QK8_1 == 0); + assert(n % QK8_0 == 0); assert(nb % 2 == 0); const block_q4_0 * restrict x = vx; @@ -2627,7 +2633,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8_t * restrict p1 = y[i].qs; int sumi = 0; - for (int j = 0; j < QK8_1/2; j++) { + for (int j = 0; j < QK8_0/2; j++) { const uint8_t v0 = p0[j]; const int i0 = (int8_t) (v0 & 0xf) - 8; @@ -2779,11 +2785,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * } static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_1; + const int nb = n / QK8_0; - assert(n % QK8_1 == 0); + assert(n % QK8_0 == 0); assert(nb % 2 == 0); - assert(QK8_1 == 2*QK4_2); + assert(QK8_0 == 2*QK4_2); const block_q4_2 * restrict x = vx; const block_q8_0 * restrict y = vy; @@ -2908,7 +2914,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * int sumi_0 = 0; int sumi_1 = 0; - for (int j = 0; j < QK8_1/4; j++) { + for (int j = 0; j < QK8_0/4; j++) { const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; @@ -2921,8 +2927,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const int i2_0 = y0[2*j + 0]; const int i3_0 = y0[2*j + 1]; - const int i2_1 = y0[2*(j + QK8_1/4) + 0]; - const int i3_1 = y0[2*(j + QK8_1/4) + 1]; + const int i2_1 = y0[2*(j + QK8_0/4) + 0]; + const int i3_1 = y0[2*(j + QK8_0/4) + 1]; sumi_0 += i0_0*i2_0 + i1_0*i3_0; sumi_1 += i0_1*i2_1 + i1_1*i3_1; @@ -8099,6 +8105,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const enum ggml_type type = src0->type; quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); @@ -8235,7 +8242,7 @@ static void ggml_compute_forward_mul_mat_q_f32( if (params->type == GGML_TASK_INIT) { char * wdata = params->wdata; - const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_1]/GGML_BLCK_SIZE[GGML_TYPE_Q8_1]; + const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -8266,7 +8273,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const int ir1 = MIN(ir0 + dr, nr); void * wdata = params->wdata; - const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_1]/GGML_BLCK_SIZE[GGML_TYPE_Q8_1]; + const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int ir = ir0; ir < ir1; ++ir) { // src0 indices @@ -11069,7 +11076,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else #endif { - cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_1]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_1]; + const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type; + cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q]; } } else { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 11d255ef9..8300a0c62 100644 --- a/ggml.h +++ b/ggml.h @@ -878,6 +878,7 @@ extern "C" { quantize_row_q_t quantize_row_q_reference; quantize_row_q_t quantize_row_q_dot; vec_dot_q_t vec_dot_q; + enum ggml_type vec_dot_type; } quantize_fns_t; quantize_fns_t ggml_internal_get_quantize_fn(size_t i);