diff --git a/ggml.c b/ggml.c index 0130afa6f..b5dc3fff4 100644 --- a/ggml.c +++ b/ggml.c @@ -6157,12 +6157,12 @@ typedef struct { } quantize_fns_t; static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { - [GGML_TYPE_Q4_0] = { + [GGML_TYPE_Q4_0] = { .dequantize_row_q = dequantize_row_q4_0, .quantize_row_q = quantize_row_q4_0, .vec_dot_q = ggml_vec_dot_q4_0, }, - [GGML_TYPE_Q4_1] = { + [GGML_TYPE_Q4_1] = { .dequantize_row_q = dequantize_row_q4_1, .quantize_row_q = quantize_row_q4_1, .vec_dot_q = ggml_vec_dot_q4_1, @@ -6218,6 +6218,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const enum ggml_type type = src0->type; quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); GGML_ASSERT(nb10 == sizeof(float)); @@ -8952,8 +8953,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; - if (node->src0->type == GGML_TYPE_F16 && - node->src1->type == GGML_TYPE_F32) { + if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing @@ -8968,11 +8968,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #else cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); #endif - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { + } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; - } else if (quantize_fns[node->src0->type].vec_dot_q && - node->src1->type == GGML_TYPE_F32) { + } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1;