From b409fd8e117f5576c4942485d243ef41570ea56b Mon Sep 17 00:00:00 2001 From: Layl Bongers <3094382+LaylBongers@users.noreply.github.com> Date: Mon, 13 May 2024 13:32:41 +0200 Subject: [PATCH] Add remaining time mix parameters --- src/llama.cpp | 79 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 174177775..287a36520 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -520,10 +520,17 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, - LLM_TENSOR_TIME_MIX_K, - LLM_TENSOR_TIME_MIX_V, - LLM_TENSOR_TIME_MIX_R, - LLM_TENSOR_TIME_MIX_G, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, @@ -1348,16 +1355,23 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_RWKV, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_TIME_MIX_K, "blk.%d.time_mix_k" }, - { LLM_TENSOR_TIME_MIX_V, "blk.%d.time_mix_v" }, - { LLM_TENSOR_TIME_MIX_R, "blk.%d.time_mix_r" }, - { LLM_TENSOR_TIME_MIX_G, "blk.%d.time_mix_g" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix.lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix.lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix.lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix.lerp_g" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix.first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix.decay" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix.key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix.value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix.receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix.gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix.ln" }, }, }, { @@ -2523,10 +2537,20 @@ struct llama_layer { struct ggml_tensor * ssm_dt_b; // rwkv - struct ggml_tensor * time_mix_k; - struct ggml_tensor * time_mix_v; - struct ggml_tensor * time_mix_r; - struct ggml_tensor * time_mix_g; + struct ggml_tensor * time_mix_lerp_k; + struct ggml_tensor * time_mix_lerp_v; + struct ggml_tensor * time_mix_lerp_r; + struct ggml_tensor * time_mix_lerp_g; + + struct ggml_tensor * time_mix_first; + struct ggml_tensor * time_mix_decay; + struct ggml_tensor * time_mix_key; + struct ggml_tensor * time_mix_value; + struct ggml_tensor * time_mix_receptance; + struct ggml_tensor * time_mix_gate; + + struct ggml_tensor * time_mix_ln; + struct ggml_tensor * time_mix_ln_b; // long rope factors struct ggml_tensor * rope_long = nullptr; @@ -8274,10 +8298,21 @@ static bool llm_load_tensors( layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}); layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}); - layer.time_mix_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_K, "weight", i), {n_embd, 1, 1}); - layer.time_mix_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_V, "weight", i), {n_embd, 1, 1}); - layer.time_mix_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_R, "weight", i), {n_embd, 1, 1}); - layer.time_mix_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_G, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}); + + // TODO: Parametrize hardcoded dimensions for first & decay + layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {64, 32}); + layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {64, 32}); + layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, n_embd}); + layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, n_embd}); + layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}); + layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {n_embd, n_embd}); + + layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}); + layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}); } }