diff --git a/src/llama.cpp b/src/llama.cpp
index e0d395c61..1fd91fcd7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8364,10 +8364,9 @@ static bool llm_load_tensors(
                     model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
                     model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
 
-                    // TODO: Parameterize this
-                    const int time_mix_extra_dim = 32;
-                    const int time_decay_extra_dim = 64;
-                    const int head_size = 64;
+                    const int time_mix_extra_dim = (n_embd == 4096) ? 64 : 32;
+                    const int time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
+                    const int head_size = hparams.wkv_head_size;
                     const int attn_hidden_size = n_embd;
                     const int ffn_size = (int)(n_embd * 3.5 / 32) * 32;
 
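For context, a standalone sketch of the heuristic this patch introduces, in plain C++ rather than llama.cpp's loader API. Only the 4096 threshold, the resulting dims, and the ffn_size rounding come from the diff above; the struct, the helper name, and the sample widths are invented here for illustration.

// Standalone sketch of the dimension heuristic introduced above; not
// llama.cpp's API. Values copied from the diff; names are illustrative.
#include <cstdio>
#include <initializer_list>

struct rwkv6_extra_dims {
    int time_mix_extra_dim;   // low-rank dim of the token-shift mixing projection
    int time_decay_extra_dim; // low-rank dim of the data-dependent decay projection
};

// Width-4096 checkpoints use doubled low-rank dims; every other width keeps
// the previous 32/64 defaults. This replaces the hardcoded TODO block.
static rwkv6_extra_dims infer_extra_dims(int n_embd) {
    return {
        (n_embd == 4096) ? 64  : 32,
        (n_embd == 4096) ? 128 : 64,
    };
}

int main() {
    for (int n_embd : {2048, 2560, 4096}) {
        const rwkv6_extra_dims d = infer_extra_dims(n_embd);
        // ffn_size rounds 3.5 * n_embd down to a multiple of 32, as in the
        // last context line of the hunk.
        const int ffn_size = (int)(n_embd * 3.5 / 32) * 32;
        std::printf("n_embd=%d -> mix=%d decay=%d ffn=%d\n",
                    n_embd, d.time_mix_extra_dim, d.time_decay_extra_dim, ffn_size);
    }
}

The head size, by contrast, is no longer guessed at all: `hparams.wkv_head_size` would be populated during hyperparameter loading, presumably from a per-architecture GGUF metadata key, so converted checkpoints carry their own value instead of the old hardcoded 64. The width check for the two low-rank dims appears to mirror the D_MIX_LORA/D_DECAY_LORA defaults in upstream RWKV-LM, which double both dims for width-4096 models.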