Add support for ARWKV7 Hybrid models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia committed 2025-01-27 17:45:43 +08:00
parent e9ba411d3e
commit f6be4dc661
7 changed files with 420 additions and 72 deletions


@@ -3555,6 +3555,84 @@ class Rwkv7Model(Rwkv6Model):
        yield (new_name, data_torch)


@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Model):
    model_arch = gguf.MODEL_ARCH.ARWKV7

    def set_vocab(self):
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        hidden_size = self.hparams["hidden_size"]
        head_size = self.hparams["head_size"]
        rms_norm_eps = self.hparams["rms_norm_eps"]
        intermediate_size = self.hparams["intermediate_size"]
        wkv_has_gate = self.hparams["wkv_has_gate"]
        assert self.hparams["wkv_version"] == 7

        # ICLR: In-Context-Learning-Rate
        lora_rank_decay = 64
        lora_rank_iclr = 64
        lora_rank_value_residual_mix = 32
        lora_rank_gate = 128 if wkv_has_gate else 0

        # RWKV isn't context limited
        self.gguf_writer.add_context_length(1048576)
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_wkv_head_size(head_size)
        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
        self.gguf_writer.add_feed_forward_length(intermediate_size)
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_token_shift_count(1)

        # required by llama.cpp, unused
        self.gguf_writer.add_head_count(0)

    lerp_weights: dict[int, dict[str, Tensor]] = {}

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if bid is not None and "self_attn.time_mixer.x_" in name:
            try:
                self.lerp_weights[bid][name] = data_torch
            except KeyError:
                self.lerp_weights[bid] = {name: data_torch}
            if all(f"model.layers.{bid}.self_attn.time_mixer.x_{i}" in self.lerp_weights[bid].keys() for i in ["r", "w", "k", "v", "a", "g"]):
                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
                data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.self_attn.time_mixer.x_{i}"].squeeze(0) for i in ["r", "w", "k", "v", "a", "g"]], dim=0)
                yield (new_name, data)
            return
        else:
            data_torch = data_torch.squeeze()
            new_name = self.map_tensor_name(name)

            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
                new_name += ".weight"

            if any(
                new_name.endswith(t) for t in [
                    "time_mix_w1.weight", "time_mix_w2.weight",
                    "time_mix_a1.weight", "time_mix_a2.weight",
                    "time_mix_v1.weight", "time_mix_v2.weight",
                    "time_mix_g1.weight", "time_mix_g2.weight",
                ]
            ):
                data_torch = data_torch.transpose(0, 1)

            if 'r_k' in new_name:
                data_torch = data_torch.flatten()

            yield (new_name, data_torch)


@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
    model_arch = gguf.MODEL_ARCH.MAMBA
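For readers following the conversion logic above, a minimal sketch of the shape arithmetic behind the fused token-shift tensor. It assumes an illustrative hidden_size of 1024 and that the six x_r/x_w/x_k/x_v/x_a/x_g parameters are stored as (1, 1, hidden_size) tensors in the HF checkpoint (which is what the squeeze(0) above suggests); the real shapes come from the model config.

import torch

hidden_size = 1024  # illustrative only; the real value comes from the HF config

# the six per-block token-shift mix coefficients, assumed here to be
# stored as (1, 1, hidden_size) tensors in the checkpoint
lerp = {i: torch.randn(1, 1, hidden_size) for i in ["r", "w", "k", "v", "a", "g"]}

# squeeze the leading dim and stack -> (6, 1, hidden_size); ggml lists
# dimensions innermost-first, so this corresponds to the {n_embd, 1, 6}
# time_mix_lerp_fused tensor that llama.cpp's load_tensors() expects
fused = torch.stack([lerp[i].squeeze(0) for i in ["r", "w", "k", "v", "a", "g"]], dim=0)
assert fused.shape == (6, 1, hidden_size)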


@@ -261,6 +261,7 @@ class MODEL_ARCH(IntEnum):
RWKV6 = auto()
RWKV6QWEN2 = auto()
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
@@ -461,6 +462,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
@@ -1214,6 +1216,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.CHANNEL_MIX_KEY,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.ARWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
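As a quick sanity check of the registrations above, the new architecture enum and its tensor list can be inspected directly from gguf-py; a minimal sketch, assuming the gguf package from this tree is on the Python path:

import gguf

# architecture string written into GGUF metadata
assert gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.ARWKV7] == "arwkv7"

# tensors permitted for this architecture, e.g. the fused token-shift tensor
assert gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.ARWKV7]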


@@ -27,8 +27,8 @@ class TensorNameMap:
"embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm
"shared", # t5
"rwkv.embeddings", # rwkv6
"model.embeddings", # rwkv7
),
# Token type embeddings
@@ -42,8 +42,8 @@ class TensorNameMap:
"embeddings.LayerNorm", # bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv6
"model.pre_ln", # rwkv7
"backbone.norm", # wavtokenizer
),
@@ -83,8 +83,8 @@ class TensorNameMap:
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"model.norm", # nemotron
"rwkv.ln_out", # rwkv6
"model.ln_out", # rwkv7
"backbone.final_layer_norm", # wavtokenizer
),
@@ -125,16 +125,16 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.blocks.{bid}.ln1", # rwkv7
),
# Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv6
"model.blocks.{bid}.ln2", # rwkv7
),
# Attention query-key-value
@@ -469,159 +469,178 @@ class TensorNameMap:
MODEL_TENSOR.TIME_MIX_W0: (
"model.blocks.{bid}.attention.w0", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.w0", # arwkv7
),
MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
"model.blocks.{bid}.attention.w1", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.w1", # arwkv7
),
MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
"model.blocks.{bid}.attention.w2", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.w2", # arwkv7
),
MODEL_TENSOR.TIME_MIX_A0: (
"model.blocks.{bid}.attention.a0", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.a0", # arwkv7
),
MODEL_TENSOR.TIME_MIX_A1: (
"model.blocks.{bid}.attention.a1", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.a1", # arwkv7
),
MODEL_TENSOR.TIME_MIX_A2: (
"model.blocks.{bid}.attention.a2", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.a2", # arwkv7
),
MODEL_TENSOR.TIME_MIX_V0: (
"model.blocks.{bid}.attention.v0", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.v0", # arwkv7
),
MODEL_TENSOR.TIME_MIX_V1: (
"model.blocks.{bid}.attention.v1", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.v1", # arwkv7
),
MODEL_TENSOR.TIME_MIX_V2: (
"model.blocks.{bid}.attention.v2", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.v2", # arwkv7
),
MODEL_TENSOR.TIME_MIX_G1: (
"model.blocks.{bid}.attention.g1", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.g1", # arwkv7
),
MODEL_TENSOR.TIME_MIX_G2: (
"model.blocks.{bid}.attention.g2", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.g2", # arwkv7
),
MODEL_TENSOR.TIME_MIX_K_K: (
"model.blocks.{bid}.attention.k_k", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.k_k", # arwkv7
),
MODEL_TENSOR.TIME_MIX_K_A: (
"model.blocks.{bid}.attention.k_a", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.k_a", # arwkv7
),
MODEL_TENSOR.TIME_MIX_R_K: (
"model.blocks.{bid}.attention.r_k", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.r_k", # arwkv7
),
MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_FIRST: (
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
),
MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv6
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv6
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.key", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.key.weight", # arwkv7
),
MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv6
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.value", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.value.weight", # arwkv7
),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv6
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.receptance", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.receptance.weight", # arwkv7
),
MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv6
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
"model.layers.{bid}.self_attn.time_mixer.gate.weight", # arwkv7
),
MODEL_TENSOR.TIME_MIX_LN: (
"rwkv.blocks.{bid}.attention.ln_x", # rwkv6
"model.blocks.{bid}.attention.ln_x" # rwkv7
),
MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.output", # rwkv7
"model.layers.{bid}.self_attn.time_mixer.output.weight", # arwkv7
),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
"model.blocks.{bid}.feed_forward.x_k", # rwkv7
),
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
),
MODEL_TENSOR.CHANNEL_MIX_KEY: (
"rwkv.blocks.{bid}.feed_forward.key", # rwkv6
"model.blocks.{bid}.feed_forward.key", # rwkv7
),
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
@@ -629,8 +648,8 @@ class TensorNameMap:
),
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
"rwkv.blocks.{bid}.feed_forward.value", # rwkv6
"model.blocks.{bid}.feed_forward.value", # rwkv7
),
MODEL_TENSOR.ATTN_Q_A: (
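These mappings are what map_tensor_name() in the converter relies on. A minimal sketch of resolving an ARWKV7 checkpoint name to its GGUF name, assuming gguf-py from this tree is importable and get_tensor_name_map() behaves as in the existing converters:

import gguf

# HF-name -> GGUF-name map for a 28-block arwkv7 model (28 layers = the 7B variant)
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.ARWKV7, 28)

# per-block decay bias of the time mixer
name = tmap.get_name("model.layers.0.self_attn.time_mixer.w0", try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.0.time_mix_w0"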


@@ -59,6 +59,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
@@ -1256,6 +1257,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
},
},
{
LLM_ARCH_ARWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE,
{


@@ -63,6 +63,7 @@ enum llm_arch {
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,


@@ -1211,17 +1211,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay);
ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr);
ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false);
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
switch (hparams.n_layer) {
case 28: type = LLM_TYPE_7B; break; // ARWKV7
default: type = LLM_TYPE_UNKNOWN;
}
} break;
@@ -3366,6 +3368,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
}
} break;
case LLM_ARCH_ARWKV7:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
const int n_lora_decay = hparams.n_lora_decay;
const int n_lora_iclr = hparams.n_lora_iclr;
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
const int n_lora_gate = hparams.n_lora_gate;
const int attn_hidden_size = n_embd;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 6}, 0);
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_CHAMELEON:
{
@@ -3953,6 +4011,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_WAVTOKENIZER_DEC:
return LLAMA_ROPE_TYPE_NONE;
@@ -4107,6 +4166,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
case LLM_ARCH_RWKV6: return true;
case LLM_ARCH_RWKV6QWEN2: return true;
case LLM_ARCH_RWKV7: return true;
case LLM_ARCH_ARWKV7: return true;
default: return false;
}
}


@@ -1097,7 +1097,11 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
);
}
struct ggml_tensor * g = nullptr;
if (layer->time_mix_g1 && layer->time_mix_g2) {
g = ggml_mul_mat(ctx, layer->time_mix_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->time_mix_g1, xg)));
}
struct ggml_tensor * a = ggml_sigmoid(ctx,
ggml_add(
ctx,
@@ -1122,6 +1126,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
if (layer->time_mix_ln && layer->time_mix_ln_b) {
// group norm with head_count groups
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
cur = ggml_norm(ctx, cur, 64e-5f);
@@ -1129,12 +1134,17 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
// Convert back to regular vectors.
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
} else {
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
}
struct ggml_tensor * rk = ggml_sum_rows(ctx,
ggml_mul(ctx, ggml_mul(ctx, k, r), ggml_reshape_2d(ctx, layer->time_mix_r_k, head_size, head_count)));
cur = ggml_add(ctx, cur, ggml_reshape_2d(ctx, ggml_mul(ctx, v, rk), n_embd, n_tokens));
if (g) {
cur = ggml_mul(ctx, cur, g);
}
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
return cur;
@@ -8015,6 +8025,114 @@ struct llm_build_context {
return gf;
}
ggml_cgraph * build_arwkv7() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
GGML_ASSERT(n_embd == hparams.n_embd_k_s() / hparams.token_shift_count);
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
struct ggml_tensor * state_copy = build_inp_s_copy();
struct ggml_tensor * state_mask = build_inp_s_mask();
struct ggml_tensor * value_first_layer = nullptr;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
// (ab)using the KV cache to store the states
struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
gf, kv_self.k_l[il], state_copy, state_mask,
hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
gf, kv_self.v_l[il], state_copy, state_mask,
hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
struct ggml_tensor * x_prev = ggml_concat(
ctx0,
token_shift,
ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
1
);
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
ggml_build_forward_expand(
gf,
ggml_cpy(
ctx0,
ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
)
);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv7_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, value_first_layer, hparams.wkv_head_size));
ggml_build_forward_expand(gf, ffn_inp);
ggml_build_forward_expand(
gf,
ggml_cpy(
ctx0,
wkv_states,
ggml_view_1d(
ctx0,
kv_self.v_l[il],
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
)
)
);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
// ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes:
// * qk-norm
@@ -8632,6 +8750,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_rwkv7();
} break;
case LLM_ARCH_ARWKV7:
{
result = llm.build_arwkv7();
} break;
case LLM_ARCH_CHAMELEON:
{
result = llm.build_chameleon();