From f6be4dc6615dbff59f2561618e34b33af7093b40 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Mon, 27 Jan 2025 17:45:43 +0800 Subject: [PATCH] Add support for ARWKV7 Hybrid models Signed-off-by: Molly Sophia --- convert_hf_to_gguf.py | 78 ++++++++++++++++++ gguf-py/gguf/constants.py | 33 ++++++++ gguf-py/gguf/tensor_mapping.py | 141 +++++++++++++++++++-------------- src/llama-arch.cpp | 35 ++++++++ src/llama-arch.h | 1 + src/llama-model.cpp | 66 ++++++++++++++- src/llama.cpp | 138 ++++++++++++++++++++++++++++++-- 7 files changed, 420 insertions(+), 72 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c427263d7..e00e96997 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3555,6 +3555,84 @@ class Rwkv7Model(Rwkv6Model): yield (new_name, data_torch) +@Model.register("RwkvHybridForCausalLM") +class ARwkv7Model(Model): + model_arch = gguf.MODEL_ARCH.ARWKV7 + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + hidden_size = self.hparams["hidden_size"] + head_size = self.hparams["head_size"] + rms_norm_eps = self.hparams["rms_norm_eps"] + intermediate_size = self.hparams["intermediate_size"] + wkv_has_gate = self.hparams["wkv_has_gate"] + assert self.hparams["wkv_version"] == 7 + + # ICLR: In-Context-Learning-Rate + lora_rank_decay = 64 + lora_rank_iclr = 64 + lora_rank_value_residual_mix = 32 + lora_rank_gate = 128 if wkv_has_gate else 0 + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_decay_lora_rank(lora_rank_decay) + self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr) + self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix) + self.gguf_writer.add_gate_lora_rank(lora_rank_gate) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_token_shift_count(1) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + lerp_weights: dict[int, dict[str, Tensor]] = {} + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if bid is not None and "self_attn.time_mixer.x_" in name: + try: + self.lerp_weights[bid][name] = data_torch + except KeyError: + self.lerp_weights[bid] = {name: data_torch} + if all(f"model.layers.{bid}.self_attn.time_mixer.x_{i}" in self.lerp_weights[bid].keys() for i in ["r", "w", "k", "v", "a", "g"]): + new_name = f"blk.{bid}.time_mix_lerp_fused.weight" + data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.self_attn.time_mixer.x_{i}"].squeeze(0) for i in ["r", "w", "k", "v", "a", "g"]], dim=0) + yield (new_name, data) + return + else: + data_torch = data_torch.squeeze() + new_name = self.map_tensor_name(name) + + if not (new_name.endswith(".weight") or new_name.endswith(".bias")): + new_name += ".weight" + + if any( + new_name.endswith(t) for t in [ + "time_mix_w1.weight", "time_mix_w2.weight", + "time_mix_a1.weight", "time_mix_a2.weight", + "time_mix_v1.weight", "time_mix_v2.weight", + "time_mix_g1.weight", "time_mix_g2.weight", + ] + ): + data_torch = data_torch.transpose(0, 1) + + if 'r_k' in new_name: + data_torch = data_torch.flatten() + + yield 
(new_name, data_torch) + + @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(Model): model_arch = gguf.MODEL_ARCH.MAMBA diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0adfe40b3..996fd2d9b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -261,6 +261,7 @@ class MODEL_ARCH(IntEnum): RWKV6 = auto() RWKV6QWEN2 = auto() RWKV7 = auto() + ARWKV7 = auto() MAMBA = auto() XVERSE = auto() COMMAND_R = auto() @@ -461,6 +462,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", @@ -1214,6 +1216,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.CHANNEL_MIX_KEY, MODEL_TENSOR.CHANNEL_MIX_VALUE, ], + MODEL_ARCH.ARWKV7: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_W0, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_A0, + MODEL_TENSOR.TIME_MIX_A1, + MODEL_TENSOR.TIME_MIX_A2, + MODEL_TENSOR.TIME_MIX_V0, + MODEL_TENSOR.TIME_MIX_V1, + MODEL_TENSOR.TIME_MIX_V2, + MODEL_TENSOR.TIME_MIX_G1, + MODEL_TENSOR.TIME_MIX_G2, + MODEL_TENSOR.TIME_MIX_K_K, + MODEL_TENSOR.TIME_MIX_K_A, + MODEL_TENSOR.TIME_MIX_R_K, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.MAMBA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a3c56e780..77dc62256 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -27,8 +27,8 @@ class TensorNameMap: "embedding.word_embeddings", # chatglm "transformer.token_embeddings", # openelm "shared", # t5 - "rwkv.embeddings", # rwkv v6 - "model.embeddings", # rwkv v7 + "rwkv.embeddings", # rwkv6 + "model.embeddings", # rwkv7 ), # Token type embeddings @@ -42,8 +42,8 @@ class TensorNameMap: "embeddings.LayerNorm", # bert "emb_ln", # nomic-bert "transformer.norm", # openelm - "rwkv.blocks.0.pre_ln", # rwkv v6 - "model.pre_ln", # rwkv v7 + "rwkv.blocks.0.pre_ln", # rwkv6 + "model.pre_ln", # rwkv7 "backbone.norm", # wavtokenizer ), @@ -83,8 +83,8 @@ class TensorNameMap: "encoder.final_layernorm", # chatglm "transformer.norm", # openelm "model.norm", # nemotron - "rwkv.ln_out", # rwkv v6 - "model.ln_out", # rwkv v7 + "rwkv.ln_out", # rwkv6 + "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer ), @@ -125,16 +125,16 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "encoder.layers.{bid}.input_layernorm", # chatglm "transformer.layers.{bid}.attn_norm", # openelm - "rwkv.blocks.{bid}.ln1", # rwkv v6 - "model.blocks.{bid}.ln1", # rwkv v7 + "rwkv.blocks.{bid}.ln1", # rwkv6 + "model.blocks.{bid}.ln1", # rwkv7 ), # Attention norm 2 MODEL_TENSOR.ATTN_NORM_2: ( "transformer.h.{bid}.ln_attn", # falcon40b "encoder.layer.{bid}.layer_norm_1", # jina-v2-code - "rwkv.blocks.{bid}.ln2", # rwkv v6 - "model.blocks.{bid}.ln2", # rwkv v7 + "rwkv.blocks.{bid}.ln2", # rwkv6 + "model.blocks.{bid}.ln2", # rwkv7 ), # Attention query-key-value @@ -468,160 
+468,179 @@ class TensorNameMap: ), MODEL_TENSOR.TIME_MIX_W0: ( - "model.blocks.{bid}.attention.w0", # rwkv7 + "model.blocks.{bid}.attention.w0", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.w0", # arwkv7 ), MODEL_TENSOR.TIME_MIX_W1: ( - "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6 - "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 - "model.blocks.{bid}.attention.w1" # rwkv7 + "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 + "model.blocks.{bid}.attention.w1", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.w1", # arwkv7 ), MODEL_TENSOR.TIME_MIX_W2: ( - "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6 - "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 - "model.blocks.{bid}.attention.w2" # rwkv7 + "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 + "model.blocks.{bid}.attention.w2", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.w2", # arwkv7 ), MODEL_TENSOR.TIME_MIX_A0: ( - "model.blocks.{bid}.attention.a0", # rwkv7 + "model.blocks.{bid}.attention.a0", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.a0", # arwkv7 ), MODEL_TENSOR.TIME_MIX_A1: ( - "model.blocks.{bid}.attention.a1", # rwkv7 + "model.blocks.{bid}.attention.a1", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.a1", # arwkv7 ), MODEL_TENSOR.TIME_MIX_A2: ( - "model.blocks.{bid}.attention.a2", # rwkv7 + "model.blocks.{bid}.attention.a2", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.a2", # arwkv7 ), MODEL_TENSOR.TIME_MIX_V0: ( - "model.blocks.{bid}.attention.v0", # rwkv7 + "model.blocks.{bid}.attention.v0", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.v0", # arwkv7 ), MODEL_TENSOR.TIME_MIX_V1: ( - "model.blocks.{bid}.attention.v1", # rwkv7 + "model.blocks.{bid}.attention.v1", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.v1", # arwkv7 ), MODEL_TENSOR.TIME_MIX_V2: ( - "model.blocks.{bid}.attention.v2", # rwkv7 + "model.blocks.{bid}.attention.v2", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.v2", # arwkv7 ), MODEL_TENSOR.TIME_MIX_G1: ( - "model.blocks.{bid}.attention.g1", # rwkv7 + "model.blocks.{bid}.attention.g1", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.g1", # arwkv7 ), MODEL_TENSOR.TIME_MIX_G2: ( - "model.blocks.{bid}.attention.g2", # rwkv7 + "model.blocks.{bid}.attention.g2", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.g2", # arwkv7 ), MODEL_TENSOR.TIME_MIX_K_K: ( - "model.blocks.{bid}.attention.k_k", # rwkv7 + "model.blocks.{bid}.attention.k_k", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.k_k", # arwkv7 ), MODEL_TENSOR.TIME_MIX_K_A: ( - "model.blocks.{bid}.attention.k_a", # rwkv7 + "model.blocks.{bid}.attention.k_a", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.k_a", # arwkv7 ), MODEL_TENSOR.TIME_MIX_R_K: ( - "model.blocks.{bid}.attention.r_k", # rwkv7 + "model.blocks.{bid}.attention.r_k", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.r_k", # arwkv7 ), MODEL_TENSOR.TIME_MIX_LERP_X: ( - "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6 "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_K: ( - "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6 "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_V: ( - "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6 
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_R: ( - "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6 "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_G: ( - "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6 "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_W: ( - "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6 "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_FIRST: ( - "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6 ), MODEL_TENSOR.TIME_MIX_DECAY: ( - "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_decay", # rwkv6 "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_DECAY_W1: ( - "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6 "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_DECAY_W2: ( - "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6 + "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6 "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_KEY: ( - "rwkv.blocks.{bid}.attention.key", # rwkv v6 - "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 - "model.blocks.{bid}.attention.key", # rwkv v7 + "rwkv.blocks.{bid}.attention.key", # rwkv6 + "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 + "model.blocks.{bid}.attention.key", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.key.weight", # arwkv7 ), MODEL_TENSOR.TIME_MIX_VALUE: ( - "rwkv.blocks.{bid}.attention.value", # rwkv v6 - "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 - "model.blocks.{bid}.attention.value", # rwkv v7 + "rwkv.blocks.{bid}.attention.value", # rwkv6 + "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 + "model.blocks.{bid}.attention.value", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.value.weight", # arwkv7 ), MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( - "rwkv.blocks.{bid}.attention.receptance", # rwkv v6 - "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 - "model.blocks.{bid}.attention.receptance", # rwkv v7 + "rwkv.blocks.{bid}.attention.receptance", # rwkv6 + "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 + "model.blocks.{bid}.attention.receptance", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.receptance.weight", # arwkv7 ), MODEL_TENSOR.TIME_MIX_GATE: ( - "rwkv.blocks.{bid}.attention.gate", # rwkv v6 - "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 + "rwkv.blocks.{bid}.attention.gate", # rwkv6 + "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 + "model.layers.{bid}.self_attn.time_mixer.gate.weight", # arwkv7 ), MODEL_TENSOR.TIME_MIX_LN: ( - "rwkv.blocks.{bid}.attention.ln_x", # rwkv v6 - "model.blocks.{bid}.attention.ln_x" # rwkv v7 + "rwkv.blocks.{bid}.attention.ln_x", # rwkv6 + "model.blocks.{bid}.attention.ln_x" # rwkv7 ), MODEL_TENSOR.TIME_MIX_OUTPUT: ( - "rwkv.blocks.{bid}.attention.output", # rwkv - "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 - "model.blocks.{bid}.attention.output", # rwkv v7 + "rwkv.blocks.{bid}.attention.output", # rwkv + "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 + "model.blocks.{bid}.attention.output", # rwkv7 + "model.layers.{bid}.self_attn.time_mixer.output.weight", # arwkv7 ), 
MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( - "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6 - "model.blocks.{bid}.feed_forward.x_k", # rwkv v7 + "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6 + "model.blocks.{bid}.feed_forward.x_k", # rwkv7 ), MODEL_TENSOR.CHANNEL_MIX_LERP_R: ( - "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6 + "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6 ), MODEL_TENSOR.CHANNEL_MIX_KEY: ( - "rwkv.blocks.{bid}.feed_forward.key", # rwkv v6 - "model.blocks.{bid}.feed_forward.key", # rwkv v7 + "rwkv.blocks.{bid}.feed_forward.key", # rwkv6 + "model.blocks.{bid}.feed_forward.key", # rwkv7 ), MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: ( @@ -629,8 +648,8 @@ class TensorNameMap: ), MODEL_TENSOR.CHANNEL_MIX_VALUE: ( - "rwkv.blocks.{bid}.feed_forward.value", # rwkv v6 - "model.blocks.{bid}.feed_forward.value", # rwkv v7 + "rwkv.blocks.{bid}.feed_forward.value", # rwkv6 + "model.blocks.{bid}.feed_forward.value", # rwkv7 ), MODEL_TENSOR.ATTN_Q_A: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 72eeb7b52..554f17990 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -59,6 +59,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -1256,6 +1257,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, }, }, + { + LLM_ARCH_ARWKV7, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GRANITE, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 193391e34..b4d459a50 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -63,6 +63,7 @@ enum llm_arch { LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, + LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 50fd51e12..4f28c5b59 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1211,17 
+1211,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); - ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - // TODO: Add variants + case 28: type = LLM_TYPE_7B; break; // ARWKV7 default: type = LLM_TYPE_UNKNOWN; } } break; @@ -3366,6 +3368,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); } + } break; + case LLM_ARCH_ARWKV7: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = 
create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_CHAMELEON: { @@ -3953,6 +4011,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: return LLAMA_ROPE_TYPE_NONE; @@ -4107,6 +4166,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) { case LLM_ARCH_RWKV6: return true; case LLM_ARCH_RWKV6QWEN2: return true; case LLM_ARCH_RWKV7: return true; + case LLM_ARCH_ARWKV7: return true; default: return false; } } diff --git a/src/llama.cpp b/src/llama.cpp index de802ce3d..41fcd4cdd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1097,7 +1097,11 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix( ); } - struct ggml_tensor * g = ggml_mul_mat(ctx, layer->time_mix_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->time_mix_g1, xg))); + struct ggml_tensor * g = nullptr; + if (layer->time_mix_g1 && layer->time_mix_g2) { + g = ggml_mul_mat(ctx, layer->time_mix_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->time_mix_g1, xg))); + } + struct ggml_tensor * a = ggml_sigmoid(ctx, ggml_add( ctx, @@ -1122,19 +1126,25 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix( cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); - // group norm with head_count groups - cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens); - cur = ggml_norm(ctx, cur, 64e-5f); + if (layer->time_mix_ln && layer->time_mix_ln_b) { + // group norm with head_count groups + cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens); + cur = ggml_norm(ctx, cur, 64e-5f); - // Convert back to regular vectors. - cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens); - cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); + // Convert back to regular vectors. 
+ cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens); + cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); + } else { + cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens); + } struct ggml_tensor * rk = ggml_sum_rows(ctx, ggml_mul(ctx, ggml_mul(ctx, k, r), ggml_reshape_2d(ctx, layer->time_mix_r_k, head_size, head_count))); cur = ggml_add(ctx, cur, ggml_reshape_2d(ctx, ggml_mul(ctx, v, rk), n_embd, n_tokens)); - cur = ggml_mul(ctx, cur, g); + if (g) { + cur = ggml_mul(ctx, cur, g); + } cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur); return cur; @@ -8015,6 +8025,114 @@ struct llm_build_context { return gf; } + ggml_cgraph * build_arwkv7() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + + GGML_ASSERT(n_embd == hparams.n_embd_k_s() / hparams.token_shift_count); + + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_tokens = ubatch.n_tokens; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * value_first_layer = nullptr; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, + gf, kv_self.k_l[il], state_copy, state_mask, + hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); + struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, + gf, kv_self.v_l[il], state_copy, state_mask, + hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); + + cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs); + + struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il); + struct ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0), + 1 + ); + + struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0), + ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) + ) + ); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv7_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, value_first_layer, hparams.wkv_head_size)); + ggml_build_forward_expand(gf, ffn_inp); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_states, + ggml_view_1d( + ctx0, + kv_self.v_l[il], + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) + ) + ) + ); + + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + 
model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm @@ -8632,6 +8750,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_rwkv7(); } break; + case LLM_ARCH_ARWKV7: + { + result = llm.build_arwkv7(); + } break; case LLM_ARCH_CHAMELEON: { result = llm.build_chameleon();
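
---
Note (not part of the patch): for reference, the standalone Python sketch below illustrates the lerp-fusion step that ARwkv7Model.modify_tensors performs in convert_hf_to_gguf.py above — the six per-channel mixing vectors x_r/x_w/x_k/x_v/x_a/x_g of one block are stacked into a single blk.{bid}.time_mix_lerp_fused.weight tensor. The (1, 1, hidden_size) checkpoint shape and the dummy tensors are assumptions for illustration only.

# Sketch of the lerp-fusion done by ARwkv7Model.modify_tensors (assumptions noted above).
import torch

def fuse_lerp(lerp_weights: dict[str, torch.Tensor], bid: int) -> tuple[str, torch.Tensor]:
    order = ["r", "w", "k", "v", "a", "g"]
    parts = [lerp_weights[f"model.layers.{bid}.self_attn.time_mixer.x_{i}"].squeeze(0)
             for i in order]
    # torch shape (6, 1, hidden) corresponds to GGUF dims {hidden, 1, 6},
    # which matches the {n_embd, 1, 6} shape expected by create_tensor() in llama-model.cpp
    return f"blk.{bid}.time_mix_lerp_fused.weight", torch.stack(parts, dim=0)

hidden_size = 8  # toy value; real models use the checkpoint's hidden_size
dummy = {f"model.layers.0.self_attn.time_mixer.x_{i}": torch.randn(1, 1, hidden_size)
         for i in "rwkvag"}
name, fused = fuse_lerp(dummy, 0)
print(name, tuple(fused.shape))  # blk.0.time_mix_lerp_fused.weight (6, 1, 8)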