ggml : add ggml_is_quantized()

This commit is contained in:
Georgi Gerganov 2023-04-18 21:32:27 +03:00
parent e435b81454
commit fe859297f3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

27
ggml.c
View file

@ -3449,6 +3449,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
};
static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = false,
[GGML_TYPE_F16] = false,
[GGML_TYPE_Q4_0] = true,
[GGML_TYPE_Q4_1] = true,
[GGML_TYPE_Q4_2] = true,
[GGML_TYPE_Q8_0] = true,
[GGML_TYPE_I8] = false,
[GGML_TYPE_I16] = false,
[GGML_TYPE_I32] = false,
};
static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@ -3709,6 +3722,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
(t0->ne[3] == t1->ne[3]);
}
static inline bool ggml_is_quantized(enum ggml_type type) {
return GGML_IS_QUANTIZED[type];
}
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
return tensor->nb[0] > tensor->nb[1];
}
@ -5830,7 +5847,7 @@ static void ggml_compute_forward_dup_f16(
}
}
}
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
} else if (ggml_is_quantized(dst->type)) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data;
@ -6042,7 +6059,7 @@ static void ggml_compute_forward_dup_f32(
}
}
}
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
} else if (ggml_is_quantized(dst->type)) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data;
@ -6405,7 +6422,7 @@ static void ggml_compute_forward_add_q_f32(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_2);
GGML_ASSERT(ggml_is_quantized(src0->type));
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -10622,7 +10639,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->n_tasks = 1;
size_t cur = 0;
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1 || node->type == GGML_TYPE_Q4_2) {
if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
}
@ -10634,7 +10651,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0;
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1 || node->src0->type == GGML_TYPE_Q4_2) {
if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
}