gguf : add array support

This commit is contained in:
Georgi Gerganov 2023-07-27 14:53:07 +03:00
parent d89533dff6
commit d2b6ca13ad
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 135 additions and 20 deletions

64
ggml.c
View file

@ -3698,7 +3698,6 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
};
static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = "f32",
[GGML_TYPE_F16] = "f16",
@ -18302,7 +18301,19 @@ struct gguf_str {
char * data;
};
union gguf_value;
// In-memory size in bytes of one value of each GGUF type, indexed by enum gguf_type.
// The array reader multiplies these sizes by the element count to allocate and
// read fixed-width array payloads in a single call.
static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
    // 8-bit integers
    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
    [GGUF_TYPE_INT8]    = sizeof(int8_t),

    // 16-bit integers
    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
    [GGUF_TYPE_INT16]   = sizeof(int16_t),

    // 32-bit integers and floats
    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
    [GGUF_TYPE_INT32]   = sizeof(int32_t),
    [GGUF_TYPE_FLOAT32] = sizeof(float),

    [GGUF_TYPE_BOOL]    = sizeof(bool),

    // strings are stored as a gguf_str header; the character data lives elsewhere
    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),

    // arrays have no fixed element size
    [GGUF_TYPE_ARRAY]   = 0, // undefined
};
static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10");
union gguf_value {
uint8_t uint8;
@ -18320,7 +18331,7 @@ union gguf_value {
enum gguf_type type;
uint32_t n;
union gguf_value * arr;
void * data;
} arr;
};
@ -18457,8 +18468,35 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (&kv->value.bool_, sizeof(kv->value.bool_), file, &offset); break;
case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(&kv->value.str, file, &offset); break;
case GGUF_TYPE_ARRAY:
GGML_ASSERT("gguf: array type not implemented");
break;
{
ok = ok && gguf_fread_el(&kv->value.arr.type, sizeof(kv->value.arr.type), file, &offset);
ok = ok && gguf_fread_el(&kv->value.arr.n, sizeof(kv->value.arr.n), file, &offset);
switch (kv->value.arr.type) {
case GGUF_TYPE_UINT8:
case GGUF_TYPE_INT8:
case GGUF_TYPE_UINT16:
case GGUF_TYPE_INT16:
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32:
case GGUF_TYPE_FLOAT32:
case GGUF_TYPE_BOOL:
{
kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]);
ok = ok && gguf_fread_el(kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], file, &offset);
} break;
case GGUF_TYPE_STRING:
{
kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str));
for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
ok = ok && gguf_fread_str(&((struct gguf_str *) kv->value.arr.data)[j], file, &offset);
}
} break;
case GGUF_TYPE_ARRAY:
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
};
} break;
case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
};
if (!ok) {
@ -18629,6 +18667,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
ggml_set_no_alloc(ctx_data, params.no_alloc);
}
fclose(file);
return ctx;
}
@ -18651,6 +18691,20 @@ void gguf_free(struct gguf_context * ctx) {
free(kv->value.str.data);
}
}
if (kv->type == GGUF_TYPE_ARRAY) {
if (kv->value.arr.data) {
if (kv->value.arr.type == GGUF_TYPE_STRING) {
for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
if (str->data) {
free(str->data);
}
}
}
free(kv->value.arr.data);
}
}
}
GGML_ALIGNED_FREE(ctx->header.kv);