ggml : add ggml_is_quantized()
This commit is contained in:
parent
e435b81454
commit
fe859297f3
1 changed file with 22 additions and 5 deletions
27
ggml.c
27
ggml.c
|
@ -3449,6 +3449,19 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
||||||
};
|
};
|
||||||
static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
|
static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
|
||||||
|
|
||||||
|
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
||||||
|
[GGML_TYPE_F32] = false,
|
||||||
|
[GGML_TYPE_F16] = false,
|
||||||
|
[GGML_TYPE_Q4_0] = true,
|
||||||
|
[GGML_TYPE_Q4_1] = true,
|
||||||
|
[GGML_TYPE_Q4_2] = true,
|
||||||
|
[GGML_TYPE_Q8_0] = true,
|
||||||
|
[GGML_TYPE_I8] = false,
|
||||||
|
[GGML_TYPE_I16] = false,
|
||||||
|
[GGML_TYPE_I32] = false,
|
||||||
|
};
|
||||||
|
static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
|
||||||
|
|
||||||
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
||||||
"NONE",
|
"NONE",
|
||||||
|
|
||||||
|
@ -3709,6 +3722,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
||||||
(t0->ne[3] == t1->ne[3]);
|
(t0->ne[3] == t1->ne[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool ggml_is_quantized(enum ggml_type type) {
|
||||||
|
return GGML_IS_QUANTIZED[type];
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
||||||
return tensor->nb[0] > tensor->nb[1];
|
return tensor->nb[0] > tensor->nb[1];
|
||||||
}
|
}
|
||||||
|
@ -5830,7 +5847,7 @@ static void ggml_compute_forward_dup_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
|
} else if (ggml_is_quantized(dst->type)) {
|
||||||
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
||||||
|
@ -6042,7 +6059,7 @@ static void ggml_compute_forward_dup_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1 || dst->type == GGML_TYPE_Q4_2) {
|
} else if (ggml_is_quantized(dst->type)) {
|
||||||
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
||||||
|
@ -6405,7 +6422,7 @@ static void ggml_compute_forward_add_q_f32(
|
||||||
GGML_ASSERT(nb1 <= nb2);
|
GGML_ASSERT(nb1 <= nb2);
|
||||||
GGML_ASSERT(nb2 <= nb3);
|
GGML_ASSERT(nb2 <= nb3);
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_2);
|
GGML_ASSERT(ggml_is_quantized(src0->type));
|
||||||
GGML_ASSERT(dst->type == src0->type);
|
GGML_ASSERT(dst->type == src0->type);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
@ -10622,7 +10639,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
node->n_tasks = 1;
|
node->n_tasks = 1;
|
||||||
|
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1 || node->type == GGML_TYPE_Q4_2) {
|
if (ggml_is_quantized(node->type)) {
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10634,7 +10651,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
|
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
|
|
||||||
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1 || node->src0->type == GGML_TYPE_Q4_2) {
|
if (ggml_is_quantized(node->src0->type)) {
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue