rename n_ctx to kv_size

commit 606873401c (parent ef96e8b1f7)
Author: Pierrick HYMBERT, 2024-02-18 20:59:26 +01:00
Committed by: Georgi Gerganov
48 changed files with 403 additions and 393 deletions

@@ -16,7 +16,7 @@
 struct my_llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_ctx = 512;
+    uint32_t kv_size = 512;
     uint32_t n_embd = 4096;
     uint32_t n_ff = 11008;
     uint32_t n_head = 32;
@@ -190,7 +190,7 @@ static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
 static void print_params(struct my_llama_hparams * params) {
     printf("%s: n_vocab : %u\n", __func__, params->n_vocab);
-    printf("%s: n_ctx : %u\n", __func__, params->n_ctx);
+    printf("%s: kv_size : %u\n", __func__, params->kv_size);
     printf("%s: n_embd : %u\n", __func__, params->n_embd);
     printf("%s: n_ff : %u\n", __func__, params->n_ff);
     printf("%s: n_head : %u\n", __func__, params->n_head);
@@ -250,7 +250,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
     };
     GGUF_GET_KEY(ctx, hparams->n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
-    GGUF_GET_KEY(ctx, hparams->n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
+    GGUF_GET_KEY(ctx, hparams->kv_size, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
     GGUF_GET_KEY(ctx, hparams->n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
     GGUF_GET_KEY(ctx, hparams->n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(ctx, hparams->n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
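
Only the destination field changes in this hunk: the GGUF metadata key remains LLM_KV_CONTEXT_LENGTH, and the false flag marks the read as optional, so kv_size keeps its default of 512 when the key is absent. A hedged sketch of that optional-read behaviour, assuming the gguf_* API that ggml.h exposed at the time; the helper name read_kv_size and the literal key string are illustrative:

    #include "ggml.h"   // assumed to declare gguf_find_key / gguf_get_val_u32
    #include <cstdint>

    // Conceptual expansion of GGUF_GET_KEY(..., false, kv(LLM_KV_CONTEXT_LENGTH)):
    // an optional read that leaves the caller's default untouched when the key
    // is missing. For the llama architecture, kv(LLM_KV_CONTEXT_LENGTH) resolves
    // to "llama.context_length".
    static uint32_t read_kv_size(const struct gguf_context * ctx, uint32_t default_kv_size) {
        const int key_id = gguf_find_key(ctx, "llama.context_length");
        if (key_id < 0) {
            return default_kv_size;   // optional key: keep the struct default (512)
        }
        return gguf_get_val_u32(ctx, key_id);
    }
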
@@ -268,7 +268,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
     }
 }
-static void init_model(struct llama_model * input, struct my_llama_model * model, const char * fn_model, uint32_t n_ctx) {
+static void init_model(struct llama_model * input, struct my_llama_model * model, const char * fn_model, uint32_t kv_size) {
     auto & hparams = model->hparams;
     std::vector<char> tn_buf;
@@ -298,7 +298,7 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
         gguf_free(mctx);
     }
     hparams.n_vocab = llama_n_vocab(input);
-    hparams.n_ctx = n_ctx;
+    hparams.kv_size = kv_size;
     // get tensors from llama_model (possibly mmapped)
     model->tok_embeddings = llama_get_model_tensor(input, tn(LLM_TENSOR_TOKEN_EMBD));
@@ -529,7 +529,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int n_past = 0;
     const int N = n_tokens;
     const auto & hparams = model->hparams;
-    const int n_ctx = hparams.n_ctx;
+    const int kv_size = hparams.kv_size;
     const int n_vocab = hparams.n_vocab;
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
@@ -558,13 +558,13 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     ggml_set_input(KQ_pos);
     // rope has so much parameters that we make a custom function for it
-    auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+    auto rope = [ctx, KQ_pos, n_rot, kv_size, rope_freq_base, rope_freq_scale]
                 (struct ggml_tensor * t) -> struct ggml_tensor * {
         // not capturing these, to silcence warnings
         const int rope_mode = 0;
         return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
+            t, KQ_pos, n_rot, rope_mode, kv_size, 0,
             rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
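
Inside the LoRA finetune graph the renamed value is simply threaded through to ggml_rope_custom, whose context-length argument now comes from kv_size; the call itself is otherwise unchanged. A condensed, hedged sketch of the same capture pattern, with the graph and tensor setup omitted and the wrapper name build_rope_demo being illustrative:

    #include "ggml.h"

    // Illustrative wrapper around the rope lambda from the hunk above: kv_size
    // is captured and forwarded where the old n_ctx used to be.
    static struct ggml_tensor * build_rope_demo(struct ggml_context * ctx,
                                                struct ggml_tensor * cur,
                                                struct ggml_tensor * KQ_pos,
                                                int n_rot, int kv_size,
                                                float rope_freq_base, float rope_freq_scale) {
        auto rope = [ctx, KQ_pos, n_rot, kv_size, rope_freq_base, rope_freq_scale]
                    (struct ggml_tensor * t) -> struct ggml_tensor * {
            const int rope_mode = 0;
            return ggml_rope_custom(ctx,
                t, KQ_pos, n_rot, rope_mode, kv_size, 0,
                rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f);
        };
        return rope(cur);
    }
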
@@ -848,7 +848,7 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
     gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
     gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
-    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx);
+    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.kv_size);
     gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
     gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
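
On the save side the value is still written under LLM_KV_CONTEXT_LENGTH, so checkpoints produced before and after the rename stay readable by the same loader. A minimal, hedged sketch of writing that metadata with the public gguf API; the helper name and literal key strings are illustrative:

    #include "ggml.h"   // assumed to declare gguf_init_empty / gguf_set_val_* / gguf_write_to_file
    #include <cstdint>

    // Illustrative: persist the renamed hyperparameter under the unchanged GGUF key.
    static void save_kv_size_demo(const char * fname, uint32_t kv_size) {
        struct gguf_context * fctx = gguf_init_empty();
        gguf_set_val_str(fctx, "general.architecture", "llama");
        gguf_set_val_u32(fctx, "llama.context_length", kv_size);   // key name did not change
        gguf_write_to_file(fctx, fname, /*only_meta=*/true);
        gguf_free(fctx);
    }
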
@@ -1554,9 +1554,9 @@ int main(int argc, char ** argv) {
     bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
     if (existed) {
-        // overwrite last n_ctx with user provided n_ctx
+        // overwrite last kv_size with user provided kv_size
         if (params.common.custom_n_ctx) {
-            model.hparams.n_ctx = params.common.n_ctx;
+            model.hparams.kv_size = params.common.n_ctx;
         }
         const bool opt_param_count_changed = (
@@ -1625,7 +1625,7 @@ int main(int argc, char ** argv) {
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
     printf("%s: opt iter %d\n", __func__, opt->iter);
-    int n_tokens = model.hparams.n_ctx;
+    int n_tokens = model.hparams.kv_size;
     int n_vocab = model.hparams.n_vocab;
     int n_batch = params.common.n_batch;
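
The last two hunks show where the value comes from at runtime: a kv_size loaded from an existing checkpoint can be overridden by the user-supplied context size, and the per-row token count for training is then derived from it. A small, self-contained sketch of that precedence; the helper resolve_kv_size and the sample values are illustrative:

    #include <cstdint>
    #include <cstdio>

    // Illustrative precedence rule: a user-provided context size overrides the
    // value loaded from the checkpoint, otherwise the checkpoint value is kept.
    static uint32_t resolve_kv_size(bool user_provided, uint32_t user_value, uint32_t checkpoint_value) {
        return user_provided ? user_value : checkpoint_value;
    }

    int main() {
        const uint32_t kv_size = resolve_kv_size(/*user_provided=*/true, 1024, 512);
        const int n_tokens = (int) kv_size;   // as in main(): tokens per row follow kv_size
        std::printf("kv_size = %u, n_tokens = %d\n", kv_size, n_tokens);
        return 0;
    }
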