llm : read arch-specific KVs

parent b19c6e4640
commit 9f28f73785

1 changed file with 33 additions and 8 deletions

llama.cpp
@@ -344,6 +344,7 @@ struct LLM_TN {
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
 };
+
 //
 // gguf helpers
 //
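The GGUF_GET_KEY helper used throughout the hunks below is not part of this diff. Based on the public gguf C API (gguf_find_key, gguf_get_kv_type, gguf_type_name), it plausibly expands to something along these lines; the exact error messages and the use of llama.cpp's internal format() helper are assumptions:

    // sketch of a typed, optionally-required KV read; throws on a type mismatch or a missing required key
    #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
    do { \
        const std::string skey(key); \
        const int kid = gguf_find_key(ctx, skey.c_str()); \
        if (kid >= 0) { \
            const enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
            if (ktype != (type)) { \
                throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
            } \
            (dst) = func(ctx, kid); \
        } else if (req) { \
            throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
        } \
    } while (0)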
@@ -1497,12 +1498,12 @@ static void llm_load_hparams(
         int n_ctx,
         float rope_freq_base,
         float rope_freq_scale) {
-    auto & hparams = model.hparams;
-
     struct gguf_context * ctx = ml.ctx_gguf;
 
     const auto kv = LLM_KV(arch);
 
+    auto & hparams = model.hparams;
+
     // get general kv
     GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
     GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
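For reference, kv(...) turns the arch-agnostic enum into an architecture-scoped key string, e.g. LLM_KV_GENERAL_ARCHITECTURE -> "general.architecture" and LLM_KV_CONTEXT_LENGTH -> "llama.context_length" for LLM_ARCH_LLAMA. A minimal sketch of that functor, assuming printf-style name tables LLM_KV_NAMES and LLM_ARCH_NAMES (the real definition lives earlier in llama.cpp and may differ in detail):

    // simplified illustration of arch-scoped key naming
    struct LLM_KV {
        LLM_KV(llm_arch arch) : arch(arch) {}

        llm_arch arch;

        std::string operator()(llm_kv kv) const {
            // e.g. "%s.context_length" formatted with "llama" -> "llama.context_length"
            return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str());
        }
    };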
@@ -1514,8 +1515,6 @@ static void llm_load_hparams(
     GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
     GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
-    GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
-    GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
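The comment above describes n_head_kv as optional, defaulting to n_head; an optional read with req = false leaves the destination untouched when the key is absent. The hunk ends before that read, so the line below is only a hedged sketch of the pattern, with LLM_KV_ATTENTION_HEAD_COUNT_KV assumed as the key name:

    // optional: override the n_head default only if the key exists
    GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));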
@@ -1523,11 +1522,36 @@ static void llm_load_hparams(
 
     // TODO: manually setting rope scale should override this
     // rope_freq_scale (inverse of the kv) is optional
+    {
     float ropescale = 1.0f;
     GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
     if (ropescale != 1.0f) {
         rope_freq_scale = 1.0f/ropescale;
     }
+    }
+
+    // sanity check for n_rot (optional)
+    {
+        hparams.n_rot = hparams.n_embd / hparams.n_head;
+
+        GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+
+        if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
+            throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
+        }
+    }
+
+    // arch-specific KVs
+    switch (arch) {
+        case LLM_ARCH_LLAMA:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+            } break;
+        default: (void)0;
+    };
 
     // TODO: generalize to non-LLaMA models
     switch (hparams.n_layer) {
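The LLM_ARCH_FALCON case is added here as an empty stub. Purely to illustrate how the switch is meant to be extended (this is not part of the commit), a hypothetical Falcon branch might read a plain LayerNorm epsilon instead of the RMS variant; both the hparams field f_norm_eps and the key enum LLM_KV_ATTENTION_LAYERNORM_EPS are assumptions here:

    // hypothetical extension, not in this diff
    case LLM_ARCH_FALCON:
        {
            GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
        } break;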
@@ -1594,6 +1618,7 @@ static void llm_load_vocab(
     // determine vocab type
     {
         std::string tokenizer_name;
 
         GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
 
         if (tokenizer_name == "llama") {
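As a usage note, the KV-reading path above sits on the raw gguf C API. A standalone sketch of reading a key directly with that API (the model path is an assumption, and at the time the gguf declarations lived in ggml.h):

    #include <cstdio>
    #include "ggml.h"

    int main() {
        struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
        struct gguf_context * ctx = gguf_init_from_file("models/7B/ggml-model-f16.gguf", params);
        if (!ctx) {
            fprintf(stderr, "failed to load gguf file\n");
            return 1;
        }

        // look up a string KV the same way GGUF_GET_KEY does internally
        const int kid = gguf_find_key(ctx, "general.architecture");
        if (kid >= 0 && gguf_get_kv_type(ctx, kid) == GGUF_TYPE_STRING) {
            printf("architecture: %s\n", gguf_get_val_str(ctx, kid));
        }

        gguf_free(ctx);
        return 0;
    }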