rename n_ctx to kv_size
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions

@@ -22,7 +22,7 @@
 struct my_llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_ctx = 512;
+    uint32_t kv_size = 512;
     uint32_t n_embd = 4096;
     uint32_t n_head = 32;
     uint32_t n_layer = 32;
@@ -112,7 +112,7 @@ static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
 static void print_params(struct my_llama_hparams * params) {
     printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
-    printf("%s: n_ctx: %u\n", __func__, params->n_ctx);
+    printf("%s: kv_size: %u\n", __func__, params->kv_size);
     printf("%s: n_embd: %u\n", __func__, params->n_embd);
     printf("%s: n_head: %u\n", __func__, params->n_head);
     printf("%s: n_ff: %u\n", __func__, params->n_ff);
@@ -272,7 +272,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     const int n_past = 0;
     const int N = n_tokens;
     const auto & hparams = model->hparams;
-    const int n_ctx = hparams.n_ctx;
+    const int kv_size = hparams.kv_size;
     const int n_vocab = hparams.n_vocab;
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
@@ -295,13 +295,13 @@ static struct ggml_tensor * llama_build_train_graphs(
     ggml_set_input(KQ_pos);

     // rope has so many parameters that we make a custom function for it
-    auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+    auto rope = [ctx, KQ_pos, n_rot, kv_size, rope_freq_base, rope_freq_scale]
                 (struct ggml_tensor * t) -> struct ggml_tensor * {
         // not capturing these, to silence warnings
         const int rope_mode = 0;

         return ggml_rope_custom(
-            ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+            ctx, t, KQ_pos, n_rot, rope_mode, kv_size, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
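For orientation: the renamed capture feeds the sixth argument of ggml_rope_custom, the context-size parameter used by RoPE frequency scaling, which this hunk switches from n_ctx to kv_size. A minimal sketch of how such a helper is typically applied while building the attention graph — the tensor names (cur, layer.wq) and shapes are assumptions for illustration, not part of this commit:

    // hedged sketch: rotate the query projection with the rope helper above;
    // cur, layer.wq, n_embd, n_head and N are assumed from the surrounding code
    struct ggml_tensor * q  = ggml_mul_mat(ctx, layer.wq, cur);
    struct ggml_tensor * q3 = ggml_reshape_3d(ctx, q, n_embd/n_head, n_head, N);
    struct ggml_tensor * qr = rope(q3); // positions come from the captured KQ_pos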
@@ -487,8 +487,8 @@ static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_contex
     GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
     GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);

-    // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
-    GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
+    // kv_size was not saved in earlier checkpoint file versions, so we make it optional here
+    GGUF_GET_KEY(fctx, model->hparams.kv_size, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));

     GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
     GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
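GGUF_GET_KEY is a local helper macro in this example; the `false` flag makes the key optional, which is what keeps checkpoints written before the context-length key existed loadable. Roughly — this is an approximation of the macro's behavior, not its literal expansion — it amounts to:

    // approximate behavior of GGUF_GET_KEY(fctx, dst, getter, type, required, key)
    const int kid = gguf_find_key(fctx, key);  // -1 when the key is absent
    if (kid >= 0) {
        dst = gguf_get_val_u32(fctx, kid);     // getter for GGUF_TYPE_UINT32
    } else {
        GGML_ASSERT(!required);                // optional keys keep the field's default
    }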
@@ -543,7 +543,7 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
     gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);

     // set hparams
-    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx );
+    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.kv_size );
     gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd );
     gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff );
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head );
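Note that only the in-memory field is renamed here; the value is still stored under kv(LLM_KV_CONTEXT_LENGTH), which for this architecture expands to a key like "llama.context_length", so checkpoint files from before and after the rename remain interchangeable. The write and read sides stay symmetric:

    // writer: renamed field, unchanged GGUF key
    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.kv_size);
    // reader: the same key fills the renamed field (optional, see the load hunk above)
    GGUF_GET_KEY(fctx, model->hparams.kv_size, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));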
@@ -945,7 +945,7 @@ int main(int argc, char ** argv) {
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lmodel);
-    model.hparams.n_ctx = params.common.n_ctx;
+    model.hparams.kv_size = params.common.n_ctx;
     model.hparams.n_embd = params.n_embd;
     model.hparams.n_head = params.n_head;
     model.hparams.n_layer = params.n_layer;
@@ -982,9 +982,9 @@ int main(int argc, char ** argv) {
     printf("%s: init model\n", __func__);
     bool existed = load_checkpoint_file(params.common.fn_checkpoint_in, &model, train);
     if (existed) {
-        // overwrite last n_ctx with user provided n_ctx
+        // overwrite last kv_size with user provided kv_size
         if (params.common.custom_n_ctx) {
-            model.hparams.n_ctx = params.common.n_ctx;
+            model.hparams.kv_size = params.common.n_ctx;
         }

         const bool opt_past_changed = opt->params.past != params.common.opt_past;
@@ -1031,7 +1031,7 @@ int main(int argc, char ** argv) {
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
     printf("%s: opt iter %d\n", __func__, opt->iter);

-    int n_tokens = model.hparams.n_ctx;
+    int n_tokens = model.hparams.kv_size;
     int n_vocab = model.hparams.n_vocab;
     int n_batch = params.common.n_batch;
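n_tokens now follows kv_size; downstream it sizes the per-batch training tensors. A hedged sketch of the kind of allocation it feeds — the tensor names are assumptions based on this example's usual structure, not shown in the diff:

    // token ids for one batch: [n_tokens, n_batch]
    struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_tokens, n_batch);
    // target probabilities over the vocabulary: [n_vocab, n_tokens, n_batch]
    struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);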