llama : simplify gguf_file_loader
parent 2906d5492d
commit 6c3f824697

1 changed file with 14 additions and 27 deletions
@@ -590,24 +590,6 @@ struct gguf_file_loader {
         read_tensor_metadata(tensors_map);
     }
 
-    uint32_t read_u32(const char * key) const {
-        int i = gguf_find_key(gguf_ctx, key);
-        if (i == -1) {
-            throw std::runtime_error(format("cannot find param with key %s\n", key));
-        }
-
-        return gguf_get_val_u32(gguf_ctx, i);
-    }
-
-    float read_f32(const char * key) const {
-        int i = gguf_find_key(gguf_ctx, key);
-        if (i == -1) {
-            throw std::runtime_error(format("cannot find param with key %s\n", key));
-        }
-
-        return gguf_get_val_f32(gguf_ctx, i);
-    }
-
     int read_n_vocab() const {
         int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
         if (i == -1) {
@@ -622,17 +604,22 @@ struct gguf_file_loader {
         // TODO: read all hparams from file
 
         hparams.n_vocab = read_n_vocab();
-        hparams.n_ctx   = read_u32("llama.context_length");
-        hparams.n_embd  = read_u32("llama.embedding_length");
-        hparams.n_ff    = read_u32("llama.feed_forward_length");
-        hparams.n_head  = read_u32("llama.attention.head_count");
-        hparams.n_layer = read_u32("llama.block_count");
-        hparams.n_rot   = read_u32("llama.rope.dimension_count");
-        hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
+        hparams.n_ctx   = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.context_length"));
+        hparams.n_embd  = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.embedding_length"));
+        hparams.n_ff    = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.feed_forward_length"));
+        hparams.n_head  = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.attention.head_count"));
+        hparams.n_layer = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.block_count"));
+        hparams.n_rot   = gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rope.dimension_count"));
+        hparams.f_rms_norm_eps = gguf_get_val_f32(gguf_ctx, gguf_find_key(gguf_ctx, "llama.rms_norm_epsilon"));
 
         // n_head_kv default to n_head
-        hparams.n_head_kv = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv") == -1 ? hparams.n_head : read_u32("llama.attention.head_count_kv");
+        hparams.n_head_kv = hparams.n_head;
+        {
+            const int idx = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv");
+            if (idx >= 0) {
+                hparams.n_head_kv = gguf_get_val_u32(gguf_ctx, idx);
+            }
+        }
     }
 
     void read_vocab() {
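For context, the removed read_u32()/read_f32() wrappers validated the key before reading, whereas the inlined gguf_get_val_u32(gguf_ctx, gguf_find_key(gguf_ctx, ...)) calls above assume the key is present (gguf_find_key returns -1 when it is not). Below is a minimal sketch of the two lookup patterns involved, assuming the gguf_* accessors as declared in ggml.h around the time of this change; the wrapper names are hypothetical and are not part of the commit.

// Sketch only, not part of this commit. Assumes the gguf_* API from ggml.h.
#include <cstdint>
#include <stdexcept>
#include <string>

#include "ggml.h" // assumed to declare gguf_find_key / gguf_get_val_u32 here

// Required key: mirrors the removed read_u32(), which threw when the key was missing.
static uint32_t gguf_u32_required(struct gguf_context * ctx, const char * key) {
    const int i = gguf_find_key(ctx, key);
    if (i == -1) {
        throw std::runtime_error(std::string("cannot find param with key ") + key);
    }
    return gguf_get_val_u32(ctx, i);
}

// Optional key: mirrors the new n_head_kv handling, which falls back to a
// default instead of throwing when the key is absent.
static uint32_t gguf_u32_or(struct gguf_context * ctx, const char * key, uint32_t def) {
    const int i = gguf_find_key(ctx, key);
    return i >= 0 ? gguf_get_val_u32(ctx, i) : def;
}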