From 6c3f824697530deb5fdf4399f7aadcd2f9139283 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Aug 2023 20:53:53 +0300 Subject: [PATCH] llama : simplify gguf_file_loader --- gguf-llama.cpp | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/gguf-llama.cpp b/gguf-llama.cpp index 9bf6cc9a9..895fc98bd 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -590,24 +590,6 @@ struct gguf_file_loader { read_tensor_metadata(tensors_map); } - uint32_t read_u32(const char * key) const { - int i = gguf_find_key(gguf_ctx, key); - if (i == -1) { - throw std::runtime_error(format("cannot find param with key %s\n", key)); - } - - return gguf_get_val_u32(gguf_ctx, i); - } - - float read_f32(const char * key) const { - int i = gguf_find_key(gguf_ctx, key); - if (i == -1) { - throw std::runtime_error(format("cannot find param with key %s\n", key)); - } - - return gguf_get_val_f32(gguf_ctx, i); - } - int read_n_vocab() const { int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens"); if (i == -1) { @@ -622,17 +604,22 @@ struct gguf_file_loader { // TODO: read all hparams from file hparams.n_vocab = read_n_vocab(); - hparams.n_ctx = read_u32("llama.context_length"); - hparams.n_embd = read_u32("llama.embedding_length"); - hparams.n_ff = read_u32("llama.feed_forward_length"); - hparams.n_head = read_u32("llama.attention.head_count"); - hparams.n_layer = read_u32("llama.block_count"); - hparams.n_rot = read_u32("llama.rope.dimension_count"); - hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon"); + hparams.n_ctx = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.context_length")); + hparams.n_embd = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.embedding_length")); + hparams.n_ff = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.feed_forward_length")); + hparams.n_head = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.attention.head_count")); + hparams.n_layer = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.block_count")); + hparams.n_rot = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rope.dimension_count")); + hparams.f_rms_norm_eps = gguf_get_val_f32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rms_norm_epsilon")); // n_head_kv default to n_head - hparams.n_head_kv = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv") == -1 ? hparams.n_head : read_u32("llama.attention.head_count_kv"); - + hparams.n_head_kv = hparams.n_head; + { + const int idx = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv"); + if (idx >= 0) { + hparams.n_head_kv = gguf_get_val_u32(gguf_ctx, idx); + } + } } void read_vocab() {