diff --git a/ggml.c b/ggml.c index c38a389aa..24cfd0009 100644 --- a/ggml.c +++ b/ggml.c @@ -3449,6 +3449,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { }; static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated"); +static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { + [GGML_TYPE_F32] = false, + [GGML_TYPE_F16] = false, + [GGML_TYPE_Q4_0] = true, + [GGML_TYPE_Q4_1] = true, + [GGML_TYPE_Q4_2] = true, + [GGML_TYPE_Q8_0] = true, + [GGML_TYPE_I8] = false, + [GGML_TYPE_I16] = false, + [GGML_TYPE_I32] = false, +}; +static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated"); + static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -3709,6 +3722,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct (t0->ne[3] == t1->ne[3]); } +static inline bool ggml_is_quantized(enum ggml_type type) { + return GGML_IS_QUANTIZED[type]; +} + static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } @@ -5830,7 +5847,7 @@ static void ggml_compute_forward_dup_f16( } } } - } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) { + } else if (ggml_is_quantized(dst->type)) { quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; size_t id = 0; uint8_t * dst_ptr = (uint8_t *) dst->data; @@ -6042,7 +6059,7 @@ static void ggml_compute_forward_dup_f32( } } } - } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) { + } else if (ggml_is_quantized(dst->type)) { quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; size_t id = 0; uint8_t * dst_ptr = (uint8_t *) dst->data; @@ -6405,7 +6422,7 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_2); + GGML_ASSERT(ggml_is_quantized(src0->type)); GGML_ASSERT(dst->type == src0->type); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -10622,7 +10639,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = 1; size_t cur = 0; - if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1 || node->type == GGML_TYPE_Q4_2) { + if (ggml_is_quantized(node->type)) { cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0]; } @@ -10634,7 +10651,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; - if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1 || node->src0->type == GGML_TYPE_Q4_2) { + if (ggml_is_quantized(node->src0->type)) { cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; }