llama : add arch member to llama_model

Georgi Gerganov 2023-08-22 22:09:56 +03:00
parent 5c5413dc14
commit 085228e1f5


@@ -938,10 +938,10 @@ struct llama_vocab {
 struct llama_model {
     e_model     type  = MODEL_UNKNOWN;
+    llm_arch    arch  = LLM_ARCH_UNKNOWN;
     llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
 
     std::string name = "n/a";
-    std::string arch = "n/a";
 
     llama_hparams hparams;
     llama_vocab   vocab;
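The hunk above swaps the free-form `std::string arch` member for a typed `llm_arch` value. For orientation, here is a minimal sketch of the enum machinery the rest of the diff leans on; the names match llama.cpp, but the value list and the table contents are abbreviated assumptions rather than the real definitions:

```cpp
#include <map>
#include <string>

// Sketch only: llama.cpp defines more architectures than shown here.
enum llm_arch {
    LLM_ARCH_LLAMA,
    LLM_ARCH_FALCON,
    LLM_ARCH_UNKNOWN,
};

static const std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA,  "llama"  },
    { LLM_ARCH_FALCON, "falcon" },
};

// Reverse lookup used by llm_load_arch(): unrecognized strings map to
// LLM_ARCH_UNKNOWN, which the loader turns into a runtime_error.
static llm_arch llm_arch_from_string(const std::string & name) {
    for (const auto & kv : LLM_ARCH_NAMES) {
        if (kv.second == name) {
            return kv.first;
        }
    }
    return LLM_ARCH_UNKNOWN;
}
```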
@@ -1481,7 +1481,7 @@ static const char * llama_model_type_name(e_model type) {
     }
 }
 
-static llm_arch llm_load_arch(llama_model_loader & ml) {
+static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
     struct gguf_context * ctx = ml.ctx_gguf;
 
     const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
@@ -1489,16 +1489,13 @@ static llm_arch llm_load_arch(llama_model_loader & ml) {
     std::string arch_name;
     GGUF_GET_KEY(ctx, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_GENERAL_ARCHITECTURE));
 
-    const llm_arch arch = llm_arch_from_string(arch_name);
-    if (arch == LLM_ARCH_UNKNOWN) {
+    model.arch = llm_arch_from_string(arch_name);
+    if (model.arch == LLM_ARCH_UNKNOWN) {
         throw std::runtime_error("unknown model architecture: '" + arch_name + "'");
     }
-
-    return arch;
 }
 
 static void llm_load_hparams(
-        llm_arch             arch,
         llama_model_loader & ml,
         llama_model        & model,
         int                  n_ctx,
@@ -1506,13 +1503,12 @@ static void llm_load_hparams(
         float rope_freq_scale) {
     struct gguf_context * ctx = ml.ctx_gguf;
 
-    const auto kv = LLM_KV(arch);
+    const auto kv = LLM_KV(model.arch);
 
     auto & hparams = model.hparams;
 
     // get general kv
     GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
-    GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
 
     // get hparams kv
     GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
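With the architecture stored on the model, `LLM_KV(model.arch)` can build the helper that expands generic key ids into architecture-prefixed GGUF key names, which is also why the explicit `general.architecture` read above becomes redundant. A rough sketch of its shape, reusing the enum sketched earlier; the real key table covers far more ids:

```cpp
#include <string>

// Sketch of the key-name helper; ids and mappings abbreviated.
enum llm_kv {
    LLM_KV_GENERAL_NAME,
    LLM_KV_CONTEXT_LENGTH,
};

struct LLM_KV {
    explicit LLM_KV(llm_arch arch) : arch(arch) {}

    llm_arch arch;

    // Expands an id into the key name stored in the GGUF file,
    // e.g. kv(LLM_KV_CONTEXT_LENGTH) -> "llama.context_length".
    // The "general.*" keys carry no arch prefix, which is why
    // LLM_KV(LLM_ARCH_UNKNOWN) works inside llm_load_arch().
    std::string operator()(llm_kv kv) const {
        switch (kv) {
            case LLM_KV_GENERAL_NAME:   return "general.name";
            case LLM_KV_CONTEXT_LENGTH: return LLM_ARCH_NAMES.at(arch) + ".context_length";
        }
        return "";
    }
};
```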
@@ -1548,7 +1544,7 @@ static void llm_load_hparams(
     }
 
     // arch-specific KVs
-    switch (arch) {
+    switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
                 GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
@@ -1593,14 +1589,13 @@ static void llm_load_hparams(
 }
 
 static void llm_load_vocab(
-        llm_arch             arch,
         llama_model_loader & ml,
         llama_model        & model) {
     auto & vocab = model.vocab;
 
     struct gguf_context * ctx = ml.ctx_gguf;
 
-    const auto kv = LLM_KV(arch);
+    const auto kv = LLM_KV(model.arch);
 
     const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
     if (token_idx == -1) {
@@ -1672,7 +1667,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // hparams
     LLAMA_LOG_INFO("%s: format       = %s\n", __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch         = %s\n", __func__, model.arch.c_str());
+    LLAMA_LOG_INFO("%s: arch         = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
     LLAMA_LOG_INFO("%s: vocab type   = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
     LLAMA_LOG_INFO("%s: n_vocab      = %u\n", __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_ctx_train  = %u\n", __func__, hparams.n_ctx_train);
@@ -1704,7 +1699,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 }
 
 static void llm_load_tensors(
-        llm_arch             arch,
         llama_model_loader & ml,
         llama_model        & model,
         int                  n_batch,
@@ -1776,9 +1770,9 @@ static void llm_load_tensors(
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_vocab = hparams.n_vocab;
 
-    const auto tn = LLM_TN(arch);
+    const auto tn = LLM_TN(model.arch);
 
-    switch (arch) {
+    switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
                 model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
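`LLM_TN(model.arch)` is the tensor-name counterpart of `LLM_KV`: it resolves a generic tensor id plus a suffix into the tensor name stored in the GGUF file. A minimal sketch under the same abbreviations; in llama.cpp the names come from a per-architecture table rather than being hard-coded:

```cpp
#include <string>

// Sketch of the tensor-name helper; ids abbreviated.
enum llm_tensor {
    LLM_TENSOR_TOKEN_EMBD,
};

struct LLM_TN {
    explicit LLM_TN(llm_arch arch) : arch(arch) {}

    llm_arch arch;

    // e.g. tn(LLM_TENSOR_TOKEN_EMBD, "weight") -> "token_embd.weight"
    std::string operator()(llm_tensor tensor, const std::string & suffix) const {
        (void) arch; // unused in this abbreviated sketch
        switch (tensor) {
            case LLM_TENSOR_TOKEN_EMBD: return "token_embd." + suffix;
        }
        return "";
    }
};
```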
@@ -1993,10 +1987,9 @@ static bool llama_model_load(
     try {
         std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
 
-        const llm_arch arch = llm_load_arch(*ml);
-
-        llm_load_hparams(arch, *ml, model, n_ctx, rope_freq_base, rope_freq_scale);
-        llm_load_vocab  (arch, *ml, model);
+        llm_load_arch   (*ml, model);
+        llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale);
+        llm_load_vocab  (*ml, model);
 
         llm_load_print_meta(*ml, model);
@@ -2010,7 +2003,7 @@ static bool llama_model_load(
         }
 
         llm_load_tensors(
-                arch, *ml, model, n_batch, n_gpu_layers,
+                *ml, model, n_batch, n_gpu_layers,
                 main_gpu, tensor_split, mul_mat_q, low_vram, memory_type,
                 use_mlock, progress_callback, progress_callback_user_data);
     } catch (const std::exception & err) {
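Taken together, the loader now reads the architecture exactly once and every later stage derives key and tensor names from `model.arch`. An end-to-end toy of that flow using the sketches above; `toy_model` and `toy_load_arch` are hypothetical stand-ins, not llama.cpp symbols:

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

struct toy_model {                     // hypothetical stand-in for llama_model
    llm_arch arch = LLM_ARCH_UNKNOWN;
};

// Mirrors the patched llm_load_arch(): store the enum on the model, throw on
// anything unrecognized so later stages never see LLM_ARCH_UNKNOWN.
static void toy_load_arch(const std::string & arch_name, toy_model & model) {
    model.arch = llm_arch_from_string(arch_name);
    if (model.arch == LLM_ARCH_UNKNOWN) {
        throw std::runtime_error("unknown model architecture: '" + arch_name + "'");
    }
}

int main() {
    toy_model model;
    toy_load_arch("llama", model);

    const auto kv = LLM_KV(model.arch); // what llm_load_hparams/llm_load_vocab build
    const auto tn = LLM_TN(model.arch); // what llm_load_tensors builds

    std::printf("%s\n", kv(LLM_KV_CONTEXT_LENGTH).c_str());           // llama.context_length
    std::printf("%s\n", tn(LLM_TENSOR_TOKEN_EMBD, "weight").c_str()); // token_embd.weight
    return 0;
}
```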