llama : simplify gguf_file_saver

Author: Georgi Gerganov
Date:   2023-08-15 11:09:26 +03:00
Commit: c9c0b758d4
Parent: 66ce19aecb
2 changed files with 31 additions and 65 deletions
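
In short: the gguf_file_saver constructor no longer takes a llama_ftype, write_hparams() becomes write_kv(), general.quantization_version is now written as GGML_QNT_VERSION instead of the requested file type, and the per-type switch collapses to one line per case. The minimalist example is also capped at 32 generated tokens and prints a trailing newline before exiting.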

View file

@@ -74,7 +74,9 @@ int main(int argc, char ** argv) {
     // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
     // example, we will just stop the loop once this cache is full or once an end of stream is detected.
-    while (llama_get_kv_cache_token_count(ctx) < max_context_size) {
+    const int n_gen = std::min(32, max_context_size);
+
+    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
         // evaluate the transformer
         if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
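
Note: with this change the example generates at most 32 tokens (or fewer if the context is smaller) instead of looping until the KV cache is full.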
@@ -114,7 +116,6 @@ int main(int argc, char ** argv) {
         // push this new token for next evaluation
        tokens_list.push_back(new_token_id);
     }
-
     llama_free(ctx);
@@ -122,5 +123,7 @@ int main(int argc, char ** argv) {
     llama_backend_free();
 
+    fprintf(stderr, "\n\n");
+
     return 0;
 }

View file

@@ -701,11 +701,11 @@ struct gguf_file_saver {
     size_t info_offset;
     size_t tensor_offset = 0;
 
-    gguf_file_saver(const char * fname, gguf_file_loader * fl, enum llama_ftype new_ftype)
+    gguf_file_saver(const char * fname, gguf_file_loader * fl)
         : file(fname, "wb"), fl(fl) {
         fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
         write_header();
-        write_hparams(new_ftype);
+        write_kv();
     }
 
     void write_header() {
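
Dropping the llama_ftype parameter decouples the saver from the quantization settings: the only place the old write_hparams() used it was the general.quantization_version key, which arguably should never have held the file type in the first place. write_kv() below stores GGML_QNT_VERSION there instead.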
@@ -744,75 +744,38 @@ struct gguf_file_saver {
         file.write_arr<float>(key, type, data);
     }
 
-    void write_hparams(enum llama_ftype new_ftype) {
+    // re-write the key-value section from the loaded file
+    void write_kv() {
         const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx);
         for (int i = 0; i < n_kv; ++i) {
             const char * key = gguf_get_key(fl->gguf_ctx, i);
             if (strcmp(key, "general.quantization_version") == 0) {
-                file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, new_ftype);
+                file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION);
             } else {
                 const gguf_type vtype = gguf_get_kv_type(fl->gguf_ctx, i);
 
-                bool bool_val;
-                float f32_val;
-                int16_t i16_val;
-                int32_t i32_val;
-                int8_t i8_val;
-                std::string str_val;
-                uint16_t u16_val;
-                uint32_t u32_val;
-                uint8_t u8_val;
-                gguf_type arr_type;
-                int n_arr;
-
                 switch (vtype) {
-                    case GGUF_TYPE_BOOL:
-                        bool_val = gguf_get_val_bool(fl->gguf_ctx, i);
-                        file.write_val<bool>(key, GGUF_TYPE_BOOL, bool_val);
-                        break;
-                    case GGUF_TYPE_FLOAT32:
-                        f32_val = gguf_get_val_f32(fl->gguf_ctx, i);
-                        file.write_val<float>(key, GGUF_TYPE_FLOAT32, f32_val);
-                        break;
-                    case GGUF_TYPE_INT16:
-                        i16_val = gguf_get_val_i16(fl->gguf_ctx, i);
-                        file.write_val<int16_t>(key, GGUF_TYPE_INT16, i16_val);
-                        break;
-                    case GGUF_TYPE_INT32:
-                        i32_val = gguf_get_val_i32(fl->gguf_ctx, i);
-                        file.write_val<int32_t>(key, GGUF_TYPE_INT32, i32_val);
-                        break;
-                    case GGUF_TYPE_INT8:
-                        i8_val = gguf_get_val_i8(fl->gguf_ctx, i);
-                        file.write_val<int8_t>(key, GGUF_TYPE_INT8, i8_val);
-                        break;
-                    case GGUF_TYPE_STRING:
-                        str_val = gguf_get_val_str(fl->gguf_ctx, i);
-                        file.write_str(key, GGUF_TYPE_STRING, str_val);
-                        break;
-                    case GGUF_TYPE_UINT16:
-                        u16_val = gguf_get_val_u16(fl->gguf_ctx, i);
-                        file.write_val<uint16_t>(key, GGUF_TYPE_UINT16, u16_val);
-                        break;
-                    case GGUF_TYPE_UINT32:
-                        u32_val = gguf_get_val_u32(fl->gguf_ctx, i);
-                        file.write_val<uint32_t>(key, GGUF_TYPE_UINT32, u32_val);
-                        break;
-                    case GGUF_TYPE_UINT8:
-                        u8_val = gguf_get_val_u8(fl->gguf_ctx, i);
-                        file.write_val<uint8_t>(key, GGUF_TYPE_UINT8, u8_val);
-                        break;
+                    case GGUF_TYPE_BOOL:    file.write_val<bool>    (key, GGUF_TYPE_BOOL,    gguf_get_val_bool(fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_FLOAT32: file.write_val<float>   (key, GGUF_TYPE_FLOAT32, gguf_get_val_f32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT16:   file.write_val<int16_t> (key, GGUF_TYPE_INT16,   gguf_get_val_i16 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT32:   file.write_val<int32_t> (key, GGUF_TYPE_INT32,   gguf_get_val_i32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT8:    file.write_val<int8_t>  (key, GGUF_TYPE_INT8,    gguf_get_val_i8  (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_STRING:  file.write_str          (key, GGUF_TYPE_STRING,  gguf_get_val_str (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT16:  file.write_val<uint16_t>(key, GGUF_TYPE_UINT16,  gguf_get_val_u16 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT32:  file.write_val<uint32_t>(key, GGUF_TYPE_UINT32,  gguf_get_val_u32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT8:   file.write_val<uint8_t> (key, GGUF_TYPE_UINT8,   gguf_get_val_u8  (fl->gguf_ctx, i)); break;
                     case GGUF_TYPE_ARRAY:
-                        arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
-                        n_arr    = gguf_get_arr_n   (fl->gguf_ctx, i);
-                        if (arr_type == GGUF_TYPE_FLOAT32) {
-                            write_hparam_arr_f32(key, arr_type, i, n_arr);
-                        } else if (arr_type == GGUF_TYPE_STRING) {
-                            write_hparam_arr_str(key, GGUF_TYPE_STRING, i, n_arr);
-                        } else {
-                            throw std::runtime_error("not implemented");
-                        }
-                        break;
+                        {
+                            const gguf_type arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
+                            const int       n_arr    = gguf_get_arr_n   (fl->gguf_ctx, i);
+                            if (arr_type == GGUF_TYPE_FLOAT32) {
+                                write_hparam_arr_f32(key, arr_type, i, n_arr);
+                            } else if (arr_type == GGUF_TYPE_STRING) {
+                                write_hparam_arr_str(key, arr_type, i, n_arr);
+                            } else {
+                                throw std::runtime_error("not implemented");
+                            }
+                        } break;
                     default:
                         throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
                 }
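
The collapsed switch leans on the saver's file.write_val<T>() / write_str() helpers, whose definitions are outside this diff. As a rough orientation only, a helper of that shape could look like the sketch below (hypothetical code, not from this commit, assuming a GGUF-v1-style layout of length-prefixed key, then type id, then raw value; write_str would differ in that the value itself is also length-prefixed):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // hypothetical sketch: write one fixed-size key-value pair
    // in a GGUF-v1-style layout
    template <typename T>
    static void write_val_sketch(FILE * fp, const char * key, uint32_t type_id, T val) {
        const uint32_t n = (uint32_t) strlen(key);

        fwrite(&n,       sizeof(n),       1, fp); // length of the key
        fwrite(key,      1,               n, fp); // key bytes, no '\0'
        fwrite(&type_id, sizeof(type_id), 1, fp); // gguf_type of the value
        fwrite(&val,     sizeof(val),     1, fp); // raw value bits
    }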
@@ -3264,7 +3227,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 
     std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false));
-    gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype);
+    gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get());
 
 #ifdef GGML_USE_K_QUANTS
     int n_attention_wv = 0;
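
With params->ftype gone from this call, general.quantization_version in the output always comes from GGML_QNT_VERSION. A quick sanity check of a saved file, assuming ggml's public gguf API (gguf_init_from_file, gguf_find_key, gguf_get_val_u32):

    #include "ggml.h"

    #include <cstdio>

    // read back general.quantization_version from a saved GGUF file
    static int check_qnt_version(const char * fname) {
        struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };

        struct gguf_context * ctx = gguf_init_from_file(fname, params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load %s\n", fname);
            return 1;
        }

        const int kid = gguf_find_key(ctx, "general.quantization_version");
        if (kid >= 0) {
            printf("quantization_version = %u (GGML_QNT_VERSION = %u)\n",
                   gguf_get_val_u32(ctx, kid), (unsigned) GGML_QNT_VERSION);
        }

        gguf_free(ctx);
        return 0;
    }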