llama : simplify gguf_file_loader

This commit is contained in:
Georgi Gerganov 2023-08-15 20:53:53 +03:00
parent 2906d5492d
commit 6c3f824697
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -590,24 +590,6 @@ struct gguf_file_loader {
read_tensor_metadata(tensors_map); read_tensor_metadata(tensors_map);
} }
uint32_t read_u32(const char * key) const {
int i = gguf_find_key(gguf_ctx, key);
if (i == -1) {
throw std::runtime_error(format("cannot find param with key %s\n", key));
}
return gguf_get_val_u32(gguf_ctx, i);
}
float read_f32(const char * key) const {
int i = gguf_find_key(gguf_ctx, key);
if (i == -1) {
throw std::runtime_error(format("cannot find param with key %s\n", key));
}
return gguf_get_val_f32(gguf_ctx, i);
}
int read_n_vocab() const { int read_n_vocab() const {
int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens"); int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
if (i == -1) { if (i == -1) {
@ -622,17 +604,22 @@ struct gguf_file_loader {
// TODO: read all hparams from file // TODO: read all hparams from file
hparams.n_vocab = read_n_vocab(); hparams.n_vocab = read_n_vocab();
hparams.n_ctx = read_u32("llama.context_length"); hparams.n_ctx = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.context_length"));
hparams.n_embd = read_u32("llama.embedding_length"); hparams.n_embd = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.embedding_length"));
hparams.n_ff = read_u32("llama.feed_forward_length"); hparams.n_ff = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.feed_forward_length"));
hparams.n_head = read_u32("llama.attention.head_count"); hparams.n_head = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.attention.head_count"));
hparams.n_layer = read_u32("llama.block_count"); hparams.n_layer = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.block_count"));
hparams.n_rot = read_u32("llama.rope.dimension_count"); hparams.n_rot = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rope.dimension_count"));
hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon"); hparams.f_rms_norm_eps = gguf_get_val_f32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rms_norm_epsilon"));
// n_head_kv default to n_head // n_head_kv default to n_head
hparams.n_head_kv = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv") == -1 ? hparams.n_head : read_u32("llama.attention.head_count_kv"); hparams.n_head_kv = hparams.n_head;
{
const int idx = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv");
if (idx >= 0) {
hparams.n_head_kv = gguf_get_val_u32(gguf_ctx, idx);
}
}
} }
void read_vocab() { void read_vocab() {