From da424b66992e9d9ac57553fe089215f5dd9e0dbe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Aug 2023 11:31:42 +0300 Subject: [PATCH] llama : gguf_file_saver write I32 --- ggml.c | 12 ++++++++---- ggml.h | 3 ++- gguf-llama.cpp | 26 ++++++++++++++++++++------ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/ggml.c b/ggml.c index c8fa60328..cdba137da 100644 --- a/ggml.c +++ b/ggml.c @@ -19039,16 +19039,20 @@ enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) { return ctx->header.kv[i].value.arr.type; } -const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { - struct gguf_kv * kv = &ctx->header.kv[key_id]; - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; - return str->data; +int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i) { + return ((int32_t *) ctx->header.kv[key_id].value.arr.data)[i]; } float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i) { return ((float *) ctx->header.kv[key_id].value.arr.data)[i]; } +const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { + struct gguf_kv * kv = &ctx->header.kv[key_id]; + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; + return str->data; +} + int gguf_get_arr_n(struct gguf_context * ctx, int i) { return ctx->header.kv[i].value.arr.n; } diff --git a/ggml.h b/ggml.h index 9a9c7ab39..79bda4538 100644 --- a/ggml.h +++ b/ggml.h @@ -1751,8 +1751,9 @@ extern "C" { GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i); - GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i); GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i); + GGML_API int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i); + GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i); GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i); GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i); diff --git a/gguf-llama.cpp b/gguf-llama.cpp index bff6213ca..76f65b71f 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -737,11 +737,24 @@ struct gguf_file_saver { file.write_arr(key, type, data); } + void write_kv_arr_i32(const std::string & key, enum gguf_type type, int i, int n_arr) { + std::vector data(n_arr); + + for (int j = 0; j < n_arr; ++j) { + int32_t val = gguf_get_arr_i32(ctx, i, j); + data[j] = val; + } + + file.write_arr(key, type, data); + } + // re-write the key-value section from the loaded file void write_kv() { const int32_t n_kv = gguf_get_n_kv(ctx); for (int i = 0; i < n_kv; ++i) { const char * key = gguf_get_key(ctx, i); + LLAMA_LOG_INFO("%s: writing key '%s'\n", __func__, key); + if (strcmp(key, "general.quantization_version") == 0) { file.write_val("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION); } else { @@ -761,12 +774,13 @@ struct gguf_file_saver { { const gguf_type arr_type = gguf_get_arr_type(ctx, i); const int n_arr = gguf_get_arr_n (ctx, i); - if (arr_type == GGUF_TYPE_FLOAT32) { - write_kv_arr_f32(key, arr_type, i, n_arr); - } else if (arr_type == GGUF_TYPE_STRING) { - write_kv_arr_str(key, arr_type, i, n_arr); - } else { - throw std::runtime_error("not implemented"); + + switch (arr_type) { + case GGUF_TYPE_FLOAT32: write_kv_arr_f32(key, arr_type, i, n_arr); break; + case GGUF_TYPE_INT32: write_kv_arr_i32(key, arr_type, i, n_arr); break; + case GGUF_TYPE_STRING: write_kv_arr_str(key, arr_type, i, n_arr); break; + default: + throw std::runtime_error(format("cannot recognize array type for key %s\n", key)); } } break; default: