llama2.c: use defines for gguf keys

Commit: 20c44711bc
Parent: df3b81ab29
Author: ochafik
Date:   2023-08-26 21:41:53 +01:00
File:   examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

@@ -13,7 +13,37 @@
 #include <algorithm>
 #include <string>
 
-// tensor names
+// GGUF keys & tensor names.
+
+#define KV_GENERAL_ARCHITECTURE "general.architecture"
+#define KV_GENERAL_NAME "general.name"
+
+#define KV_TOKENIZER_MODEL "tokenizer.ggml.model"
+#define KV_TOKENIZER_LIST "tokenizer.ggml.tokens"
+#define KV_TOKENIZER_TOKEN_TYPE "tokenizer.ggml.token_type"
+#define KV_TOKENIZER_SCORES "tokenizer.ggml.scores"
+#define KV_TOKENIZER_MERGES "tokenizer.ggml.merges"
+#define KV_TOKENIZER_BOS_ID "tokenizer.ggml.bos_token_id"
+#define KV_TOKENIZER_EOS_ID "tokenizer.ggml.eos_token_id"
+#define KV_TOKENIZER_UNK_ID "tokenizer.ggml.unknown_token_id"
+#define KV_TOKENIZER_SEP_ID "tokenizer.ggml.seperator_token_id"
+#define KV_TOKENIZER_PAD_ID "tokenizer.ggml.padding_token_id"
+#define KV_TOKENIZER_HF_JSON "tokenizer.huggingface.json"
+
+#define KV_CONTEXT_LENGTH "llama.context_length"
+#define KV_EMBEDDING_LENGTH "llama.embedding_length"
+#define KV_BLOCK_COUNT "llama.block_count"
+#define KV_FEED_FORWARD_LENGTH "llama.feed_forward_length"
+#define KV_USE_PARALLEL_RESIDUAL "llama.use_parallel_residual"
+#define KV_TENSOR_DATA_LAYOUT "llama.tensor_data_layout"
+
+#define KV_ATTENTION_HEAD_COUNT "llama.attention.head_count"
+#define KV_ATTENTION_HEAD_COUNT_KV "llama.attention.head_count_kv"
+#define KV_ATTENTION_LAYERNORM_RMS_EPS "llama.attention.layer_norm_rms_epsilon"
+
+#define KV_ROPE_DIMENSION_COUNT "llama.rope.dimension_count"
+#define KV_ROPE_SCALE_LINEAR "llama.rope.scale_linear"
+
 #define TN_TOKEN_EMBD "token_embd.weight"
 #define TN_OUTPUT_NORM "output_norm.weight"
 #define TN_OUTPUT "output.weight"
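
For context on how these keys are consumed, here is a minimal reader-side sketch (an illustration, not part of this commit) using the same gguf API that appears later in this diff; gguf_find_key returns a negative index when a key is absent:

    static uint32_t read_u32_or_die(struct gguf_context * ctx, const char * key) {
        const int idx = gguf_find_key(ctx, key); // negative when the key is missing
        GGML_ASSERT(idx >= 0);
        return gguf_get_val_u32(ctx, idx);
    }

    // e.g.: const uint32_t n_ctx = read_u32_or_die(ctx, KV_CONTEXT_LENGTH);

The helper name read_u32_or_die is hypothetical; the point of the defines is that each key string now has a single definition, so the writer and any reader cannot drift apart on spelling.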
@@ -34,6 +64,7 @@
 #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
 #define LLAMA_FILE_VERSION_GGJT_V3 3
 
+#define TOKENIZER_NAME "llama"
 #define UNKNOWN_TOKEN_ID 0
 #define BOS_TOKEN_ID 1
 #define EOS_TOKEN_ID 2
@@ -233,6 +264,8 @@ struct my_llama_layer {
 struct my_llama_model {
     struct ggml_context * ctx = NULL;
 
+    std::string name;
+
     my_llama_hparams hparams;
 
     struct ggml_tensor * tok_embeddings;
@@ -543,19 +576,19 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
     struct gguf_context * ctx = gguf_init_from_file(filename, params);
     GGML_ASSERT(ctx != NULL);
 
-    const int model_idx = gguf_find_key(ctx, "tokenizer.ggml.model");
+    const int model_idx = gguf_find_key(ctx, KV_TOKENIZER_MODEL);
     GGML_ASSERT(model_idx >= 0);
 
     std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
-    GGML_ASSERT(tokenizer_name == "llama");
+    GGML_ASSERT(tokenizer_name == TOKENIZER_NAME);
 
-    const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
+    const int token_idx = gguf_find_key(ctx, KV_TOKENIZER_LIST);
     GGML_ASSERT(token_idx >= 0);
 
-    const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
+    const int score_idx = gguf_find_key(ctx, KV_TOKENIZER_SCORES);
     GGML_ASSERT(score_idx >= 0);
     const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
 
-    const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
+    const int toktype_idx = gguf_find_key(ctx, KV_TOKENIZER_TOKEN_TYPE);
     GGML_ASSERT(toktype_idx >= 0);
     const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
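
The same find-then-assert pattern extends to string arrays. A hedged sketch of how the token list written under KV_TOKENIZER_LIST pairs up with the scores and token types above (gguf_get_arr_n and gguf_get_arr_str belong to the same gguf API; the loop itself is illustrative, not from this commit):

    const int n_vocab = gguf_get_arr_n(ctx, token_idx);
    for (int i = 0; i < n_vocab; i++) {
        const char * tok = gguf_get_arr_str(ctx, token_idx, i);
        // scores[i] and toktypes[i] describe this same token
    }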
@@ -694,30 +727,31 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
         scores.push_back(token_data.score);
         token_types.push_back(token_data.type);
     }
-    gguf_set_arr_str(ctx, "tokenizer.ggml.tokens", tokens.data(), tokens.size());
-    gguf_set_arr_data(ctx, "tokenizer.ggml.scores", GGUF_TYPE_FLOAT32, scores.data(), scores.size());
-    gguf_set_arr_data(ctx, "tokenizer.ggml.token_type", GGUF_TYPE_INT32, token_types.data(), token_types.size());
+    gguf_set_arr_str(ctx, KV_TOKENIZER_LIST, tokens.data(), tokens.size());
+    gguf_set_arr_data(ctx, KV_TOKENIZER_SCORES, GGUF_TYPE_FLOAT32, scores.data(), scores.size());
+    gguf_set_arr_data(ctx, KV_TOKENIZER_TOKEN_TYPE, GGUF_TYPE_INT32, token_types.data(), token_types.size());
 
-    gguf_set_val_str(ctx, "tokenizer.ggml.model", "llama");
+    gguf_set_val_str(ctx, KV_TOKENIZER_MODEL, TOKENIZER_NAME);
 
-    gguf_set_val_str(ctx, "general.name", "llama2.c");
-    gguf_set_val_str(ctx, "general.architecture", "llama");
+    gguf_set_val_str(ctx, KV_GENERAL_ARCHITECTURE, "llama");
+    gguf_set_val_str(ctx, KV_GENERAL_NAME, "llama");
 
     // special tokens
-    gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", UNKNOWN_TOKEN_ID);
-    gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id", BOS_TOKEN_ID);
-    gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id", EOS_TOKEN_ID);
-    gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id", -1);
-    gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id", -1);
+    gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID);
+    gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, -1);
+    gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, -1);
 
-    gguf_set_val_u32(ctx, "llama.context_length", model->hparams.n_ctx);
-    gguf_set_val_u32(ctx, "llama.embedding_length", model->hparams.n_embd);
-    gguf_set_val_u32(ctx, "llama.feed_forward_length", model->hparams.n_ff);
-    gguf_set_val_u32(ctx, "llama.attention.head_count", model->hparams.n_head);
-    gguf_set_val_u32(ctx, "llama.block_count", model->hparams.n_layer);
-    gguf_set_val_u32(ctx, "llama.rope.dimension_count", model->hparams.n_rot);
-    gguf_set_val_f32(ctx, "llama.attention.layer_norm_rms_epsilon", 1e-5f);
-    // // n_head_kv is optional, default to n_head
+    gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx);
+    gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd);
+    gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff);
+    gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
+    // n_head_kv is optional, default to n_head
+    // gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, ...);
+    gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer);
+    gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot);
+    gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
 
     // write tensors
     ggml_set_name(model->tok_embeddings, TN_TOKEN_EMBD);
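
Note that KV_ATTENTION_HEAD_COUNT_KV is deliberately left unwritten. A sketch of the reader-side convention the comment implies (an assumption, not shown in this commit; n_head stands for the head count already read back):

    const int idx = gguf_find_key(ctx, KV_ATTENTION_HEAD_COUNT_KV);
    const uint32_t n_head_kv = idx >= 0 ? gguf_get_val_u32(ctx, idx) : n_head;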
@@ -880,6 +914,14 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
     return true;
 }
 
+std::string basename(const std::string &path) {
+    size_t pos = path.find_last_of("/");
+    if (pos == std::string::npos) {
+        return path;
+    }
+    return path.substr(pos + 1);
+}
+
 int main(int argc, char ** argv) {
     struct train_params params = get_default_train_params();
     if (!params_parse(argc, argv, &params)) {
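
A quick check of the new helper's behavior (illustrative, assuming <cassert> and the basename above):

    assert(basename("out/stories15M.bin") == "stories15M.bin");
    assert(basename("stories15M.bin")     == "stories15M.bin");

Since only '/' is treated as a separator, a Windows-style path such as "models\stories15M.bin" would be returned unchanged.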
@@ -922,6 +964,7 @@ int main(int argc, char ** argv) {
     model.ctx = ggml_init(lcparams);
     init_model(&model);
 
+    model.name = basename(params.fn_llama2c_model);
     save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
 
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);