llama : gguf_file_saver write I32

This commit is contained in:
Georgi Gerganov 2023-08-15 11:31:42 +03:00
parent 9574f41818
commit da424b6699
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 30 additions and 11 deletions

12
ggml.c
View file

@ -19039,16 +19039,20 @@ enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].value.arr.type; return ctx->header.kv[i].value.arr.type;
} }
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i) {
struct gguf_kv * kv = &ctx->header.kv[key_id]; return ((int32_t *) ctx->header.kv[key_id].value.arr.data)[i];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
return str->data;
} }
float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i) { float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i) {
return ((float *) ctx->header.kv[key_id].value.arr.data)[i]; return ((float *) ctx->header.kv[key_id].value.arr.data)[i];
} }
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
struct gguf_kv * kv = &ctx->header.kv[key_id];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
return str->data;
}
int gguf_get_arr_n(struct gguf_context * ctx, int i) { int gguf_get_arr_n(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].value.arr.n; return ctx->header.kv[i].value.arr.n;
} }

3
ggml.h
View file

@ -1751,8 +1751,9 @@ extern "C" {
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i); GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);
GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i); GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i);
GGML_API int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i);
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i); GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i); GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);

View file

@ -737,11 +737,24 @@ struct gguf_file_saver {
file.write_arr<float>(key, type, data); file.write_arr<float>(key, type, data);
} }
void write_kv_arr_i32(const std::string & key, enum gguf_type type, int i, int n_arr) {
std::vector<int32_t> data(n_arr);
for (int j = 0; j < n_arr; ++j) {
int32_t val = gguf_get_arr_i32(ctx, i, j);
data[j] = val;
}
file.write_arr<int32_t>(key, type, data);
}
// re-write the key-value section from the loaded file // re-write the key-value section from the loaded file
void write_kv() { void write_kv() {
const int32_t n_kv = gguf_get_n_kv(ctx); const int32_t n_kv = gguf_get_n_kv(ctx);
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
const char * key = gguf_get_key(ctx, i); const char * key = gguf_get_key(ctx, i);
LLAMA_LOG_INFO("%s: writing key '%s'\n", __func__, key);
if (strcmp(key, "general.quantization_version") == 0) { if (strcmp(key, "general.quantization_version") == 0) {
file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION); file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION);
} else { } else {
@ -761,12 +774,13 @@ struct gguf_file_saver {
{ {
const gguf_type arr_type = gguf_get_arr_type(ctx, i); const gguf_type arr_type = gguf_get_arr_type(ctx, i);
const int n_arr = gguf_get_arr_n (ctx, i); const int n_arr = gguf_get_arr_n (ctx, i);
if (arr_type == GGUF_TYPE_FLOAT32) {
write_kv_arr_f32(key, arr_type, i, n_arr); switch (arr_type) {
} else if (arr_type == GGUF_TYPE_STRING) { case GGUF_TYPE_FLOAT32: write_kv_arr_f32(key, arr_type, i, n_arr); break;
write_kv_arr_str(key, arr_type, i, n_arr); case GGUF_TYPE_INT32: write_kv_arr_i32(key, arr_type, i, n_arr); break;
} else { case GGUF_TYPE_STRING: write_kv_arr_str(key, arr_type, i, n_arr); break;
throw std::runtime_error("not implemented"); default:
throw std::runtime_error(format("cannot recognize array type for key %s\n", key));
} }
} break; } break;
default: default: