llm : read arch-specific KVs
This commit is contained in:
parent
b19c6e4640
commit
9f28f73785
1 changed files with 33 additions and 8 deletions
41
llama.cpp
41
llama.cpp
|
@ -344,6 +344,7 @@ struct LLM_TN {
|
|||
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// gguf helpers
|
||||
//
|
||||
|
@ -1497,12 +1498,12 @@ static void llm_load_hparams(
|
|||
int n_ctx,
|
||||
float rope_freq_base,
|
||||
float rope_freq_scale) {
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
struct gguf_context * ctx = ml.ctx_gguf;
|
||||
|
||||
const auto kv = LLM_KV(arch);
|
||||
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
// get general kv
|
||||
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
|
||||
GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
|
||||
|
@ -1514,8 +1515,6 @@ static void llm_load_hparams(
|
|||
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
|
||||
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
|
||||
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
|
||||
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
|
||||
|
||||
// n_head_kv is optional, default to n_head
|
||||
hparams.n_head_kv = hparams.n_head;
|
||||
|
@ -1523,12 +1522,37 @@ static void llm_load_hparams(
|
|||
|
||||
// TODO: manually setting rope scale should override this
|
||||
// rope_freq_scale (inverse of the kv) is optional
|
||||
float ropescale = 1.0f;
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
||||
if (ropescale != 1.0f) {
|
||||
rope_freq_scale = 1.0f/ropescale;
|
||||
{
|
||||
float ropescale = 1.0f;
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
||||
if (ropescale != 1.0f) {
|
||||
rope_freq_scale = 1.0f/ropescale;
|
||||
}
|
||||
}
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
{
|
||||
hparams.n_rot = hparams.n_embd / hparams.n_head;
|
||||
|
||||
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
|
||||
|
||||
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
|
||||
}
|
||||
}
|
||||
|
||||
// arch-specific KVs
|
||||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
|
||||
} break;
|
||||
case LLM_ARCH_FALCON:
|
||||
{
|
||||
} break;
|
||||
default: (void)0;
|
||||
};
|
||||
|
||||
// TODO: generalize to non-LLaMA models
|
||||
switch (hparams.n_layer) {
|
||||
case 26: model.type = e_model::MODEL_3B; break;
|
||||
|
@ -1594,6 +1618,7 @@ static void llm_load_vocab(
|
|||
// determine vocab type
|
||||
{
|
||||
std::string tokenizer_name;
|
||||
|
||||
GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
|
||||
|
||||
if (tokenizer_name == "llama") {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue