llm : read arch-specific KVs

This commit is contained in:
Georgi Gerganov 2023-08-22 20:34:17 +03:00
parent b19c6e4640
commit 9f28f73785
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -344,6 +344,7 @@ struct LLM_TN {
return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
}
};
//
// gguf helpers
//
@ -1497,12 +1498,12 @@ static void llm_load_hparams(
int n_ctx,
float rope_freq_base,
float rope_freq_scale) {
auto & hparams = model.hparams;
struct gguf_context * ctx = ml.ctx_gguf;
const auto kv = LLM_KV(arch);
auto & hparams = model.hparams;
// get general kv
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
@ -1514,8 +1515,6 @@ static void llm_load_hparams(
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
@ -1523,12 +1522,37 @@ static void llm_load_hparams(
// TODO: manually setting rope scale should override this
// rope_freq_scale (inverse of the kv) is optional
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (ropescale != 1.0f) {
rope_freq_scale = 1.0f/ropescale;
{
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (ropescale != 1.0f) {
rope_freq_scale = 1.0f/ropescale;
}
}
// sanity check for n_rot (optional)
{
hparams.n_rot = hparams.n_embd / hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
}
}
// arch-specific KVs
switch (arch) {
case LLM_ARCH_LLAMA:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
} break;
case LLM_ARCH_FALCON:
{
} break;
default: (void)0;
};
// TODO: generalize to non-LLaMA models
switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_3B; break;
@ -1594,6 +1618,7 @@ static void llm_load_vocab(
// determine vocab type
{
std::string tokenizer_name;
GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
if (tokenizer_name == "llama") {