WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-12-24 09:36:30 +08:00
parent f66f582927
commit f298f03970
11 changed files with 376 additions and 30 deletions

View file

@ -3279,6 +3279,64 @@ class Rwkv6Model(Model):
yield (new_name, data_torch) yield (new_name, data_torch)
@Model.register("RWKV6Qwen2ForCausalLM")
class RWKV6Qwen2Model(Model):
    """Conversion support for RWKV6Qwen2 (QRWKV6) checkpoints.

    A hybrid architecture: Qwen2-style FFN/vocab with RWKV-6 time-mix layers,
    using grouped key/value heads like GQA.
    """

    model_arch = gguf.MODEL_ARCH.RWKV6QWEN2

    def set_vocab(self):
        """Prefer a sentencepiece tokenizer; fall back to GPT-2 BPE if none is found."""
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write the RWKV6QWEN2 hyper-parameters into the GGUF header."""
        hp = self.hparams
        n_layer = hp["num_hidden_layers"]
        n_head = hp["num_attention_heads"]
        n_head_kv = hp["num_key_value_heads"]
        n_embd = hp["hidden_size"]
        head_dim = n_embd // n_head
        eps = hp["rms_norm_eps"]
        n_ff = hp["intermediate_size"]
        # wider low-rank (LoRA-style) projections for larger models
        lora_mix_dim = 64 if n_embd >= 4096 else 32
        lora_decay_dim = 128 if n_embd >= 4096 else 64

        # RWKV isn't context limited
        self.gguf_writer.add_context_length(1048576)
        self.gguf_writer.add_embedding_length(n_embd)
        self.gguf_writer.add_block_count(n_layer)
        self.gguf_writer.add_wkv_head_size(head_dim)
        self.gguf_writer.add_time_mix_extra_dim(lora_mix_dim)
        self.gguf_writer.add_time_decay_extra_dim(lora_decay_dim)
        self.gguf_writer.add_feed_forward_length(n_ff)
        self.gguf_writer.add_file_type(self.ftype)

        # special parameters for time_mixing in RWKV6QWEN2
        self.gguf_writer.add_layer_norm_rms_eps(eps)
        self.gguf_writer.add_token_shift_count(1)

        # RWKV6QWEN2 uses grouped key/value heads, like GQA
        self.gguf_writer.add_head_count_kv(n_head_kv)

        # required by llama.cpp, unused
        self.gguf_writer.add_head_count(0)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map a source tensor name and reorient its data for ggml's expected layout."""
        mapped = self.map_tensor_name(name)

        # bare state tensors get an explicit ".weight" suffix
        if not mapped.endswith((".weight", ".bias")):
            mapped = mapped + ".weight"

        # low-rank projection matrices are stored transposed relative to ggml
        if mapped.endswith(("time_mix_w1.weight", "time_mix_decay_w1.weight", "time_mix_decay_w2.weight")):
            data_torch = data_torch.transpose(0, 1)
        if mapped.endswith("time_mix_w2.weight"):
            data_torch = data_torch.permute(0, 2, 1)
        # decay and lerp tensors are flattened to plain vectors
        if mapped.endswith("time_mix_decay.weight") or "lerp" in mapped:
            data_torch = data_torch.squeeze()

        yield (mapped, data_torch)
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model): class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA model_arch = gguf.MODEL_ARCH.MAMBA

View file

@ -113,6 +113,7 @@ class Keys:
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
RESIDUAL_SCALE = "{arch}.residual_scale" RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale" EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
class Attention: class Attention:
HEAD_COUNT = "{arch}.attention.head_count" HEAD_COUNT = "{arch}.attention.head_count"
@ -252,6 +253,7 @@ class MODEL_ARCH(IntEnum):
GEMMA2 = auto() GEMMA2 = auto()
STARCODER2 = auto() STARCODER2 = auto()
RWKV6 = auto() RWKV6 = auto()
RWKV6QWEN2 = auto()
MAMBA = auto() MAMBA = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
@ -434,6 +436,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
@ -1093,6 +1096,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE, MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
MODEL_TENSOR.CHANNEL_MIX_VALUE, MODEL_TENSOR.CHANNEL_MIX_VALUE,
], ],
MODEL_ARCH.RWKV6QWEN2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_LERP_X,
MODEL_TENSOR.TIME_MIX_LERP_K,
MODEL_TENSOR.TIME_MIX_LERP_V,
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
MODEL_TENSOR.TIME_MIX_DECAY_W2,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_GATE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MAMBA: [ MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -736,6 +736,9 @@ class GGUFWriter:
def add_wkv_head_size(self, size: int) -> None: def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size) self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
def add_token_shift_count(self, count: int) -> None:
    """Write the `{arch}.token_shift_count` KV entry.

    Used by RWKV-family models to declare how many token-shift states each
    layer keeps (the recurrent k-state is sized as count * n_embd).
    """
    self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
def add_layer_norm_eps(self, value: float) -> None: def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

View file

@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon "transformer.word_embeddings", # falcon
"word_embeddings", # bloom "word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2
"tok_embeddings", # llama-pth "tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert "embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon "language_model.embedding.word_embeddings", # persimmon
@ -457,34 +457,42 @@ class TensorNameMap:
MODEL_TENSOR.TIME_MIX_W1: ( MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_W2: ( MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_X: ( MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_K: ( MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_V: ( MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_R: ( MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_G: ( MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LERP_W: ( MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6 "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_FIRST: ( MODEL_TENSOR.TIME_MIX_FIRST: (
@ -493,30 +501,37 @@ class TensorNameMap:
MODEL_TENSOR.TIME_MIX_DECAY: ( MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_DECAY_W1: ( MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_DECAY_W2: ( MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6 "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_KEY: ( MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv "rwkv.blocks.{bid}.attention.key", # rwkv
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_VALUE: ( MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv "rwkv.blocks.{bid}.attention.value", # rwkv
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv "rwkv.blocks.{bid}.attention.receptance", # rwkv
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_GATE: ( MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv "rwkv.blocks.{bid}.attention.gate", # rwkv
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
), ),
MODEL_TENSOR.TIME_MIX_LN: ( MODEL_TENSOR.TIME_MIX_LN: (
@ -524,7 +539,8 @@ class TensorNameMap:
), ),
MODEL_TENSOR.TIME_MIX_OUTPUT: ( MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv "rwkv.blocks.{bid}.attention.output", # rwkv
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
), ),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( MODEL_TENSOR.CHANNEL_MIX_LERP_K: (

View file

@ -55,6 +55,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
@ -102,6 +103,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -1142,6 +1144,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
}, },
}, },
{
LLM_ARCH_RWKV6QWEN2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
{ LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
{ LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
{ LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
{ {

View file

@ -59,6 +59,7 @@ enum llm_arch {
LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE, LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6, LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
@ -106,6 +107,7 @@ enum llm_kv {
LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE, LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE, LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_HEAD_COUNT_KV,

View file

@ -52,7 +52,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
uint32_t llama_hparams::n_embd_k_s() const { uint32_t llama_hparams::n_embd_k_s() const {
if (wkv_head_size != 0) { if (wkv_head_size != 0) {
// for RWKV models // for RWKV models
return 2 * n_embd; return token_shift_count * n_embd;
} }
// TODO: maybe support other convolution strides than 1 // TODO: maybe support other convolution strides than 1

View file

@ -68,6 +68,7 @@ struct llama_hparams {
uint32_t time_mix_extra_dim = 0; uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0; uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0; uint32_t wkv_head_size = 0;
uint32_t token_shift_count = 2;
float rope_attn_factor = 1.0f; float rope_attn_factor = 1.0f;
float rope_freq_base_train; float rope_freq_base_train;

View file

@ -1017,12 +1017,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
} }
} break; } break;
case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1_6B; break; case 24: model.type = e_model::MODEL_1_6B; break;
@ -1033,6 +1036,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} break; } break;
case 61: model.type = e_model::MODEL_14B; break; case 61: model.type = e_model::MODEL_14B; break;
case 64: model.type = e_model::MODEL_32B; break;
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;

View file

@ -239,14 +239,17 @@ struct llama_layer {
struct ggml_tensor * time_mix_lerp_r = nullptr; struct ggml_tensor * time_mix_lerp_r = nullptr;
struct ggml_tensor * time_mix_lerp_g = nullptr; struct ggml_tensor * time_mix_lerp_g = nullptr;
struct ggml_tensor * time_mix_first = nullptr; struct ggml_tensor * time_mix_first = nullptr;
struct ggml_tensor * time_mix_decay = nullptr; struct ggml_tensor * time_mix_decay = nullptr;
struct ggml_tensor * time_mix_decay_w1 = nullptr; struct ggml_tensor * time_mix_decay_w1 = nullptr;
struct ggml_tensor * time_mix_decay_w2 = nullptr; struct ggml_tensor * time_mix_decay_w2 = nullptr;
struct ggml_tensor * time_mix_key = nullptr; struct ggml_tensor * time_mix_key = nullptr;
struct ggml_tensor * time_mix_value = nullptr; struct ggml_tensor * time_mix_key_b = nullptr;
struct ggml_tensor * time_mix_receptance = nullptr; struct ggml_tensor * time_mix_value = nullptr;
struct ggml_tensor * time_mix_gate = nullptr; struct ggml_tensor * time_mix_value_b = nullptr;
struct ggml_tensor * time_mix_receptance = nullptr;
struct ggml_tensor * time_mix_receptance_b = nullptr;
struct ggml_tensor * time_mix_gate = nullptr;
struct ggml_tensor * time_mix_ln = nullptr; struct ggml_tensor * time_mix_ln = nullptr;
struct ggml_tensor * time_mix_ln_b = nullptr; struct ggml_tensor * time_mix_ln_b = nullptr;

View file

@ -2151,6 +2151,63 @@ static bool llm_load_tensors(
} }
} break; } break;
case LLM_ARCH_RWKV6QWEN2:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
const int time_mix_extra_dim = hparams.time_mix_extra_dim;
const int time_decay_extra_dim = hparams.time_decay_extra_dim;
const int head_size = hparams.wkv_head_size;
const int attn_hidden_size = n_embd;
const int n_head_kv = hparams.n_head_kv();
int attn_key_value_size;
if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
attn_key_value_size = attn_hidden_size;
} else {
attn_key_value_size = n_head_kv * head_size;
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
// optional bias tensors
layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_CHAMELEON: case LLM_ARCH_CHAMELEON:
{ {
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -3238,7 +3295,10 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
const struct llama_layer * layer, const struct llama_layer * layer,
struct ggml_tensor * cur, struct ggml_tensor * cur,
struct ggml_tensor * x_prev, struct ggml_tensor * x_prev,
struct ggml_tensor ** wkv_state) { struct ggml_tensor ** wkv_state,
size_t head_count_kv,
bool key_decay,
bool skip_groupnorm) {
size_t n_embd = cur->ne[0]; size_t n_embd = cur->ne[0];
size_t n_seq_tokens = cur->ne[1]; size_t n_seq_tokens = cur->ne[1];
size_t n_seqs = cur->ne[2]; size_t n_seqs = cur->ne[2];
@ -3248,6 +3308,8 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
size_t n_tokens = n_seqs * n_seq_tokens; size_t n_tokens = n_seqs * n_seq_tokens;
bool is_qrwkv = layer->time_mix_first == nullptr;
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur); struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens); sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens);
@ -3332,14 +3394,32 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
cur cur
); );
struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens); struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens); struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens); struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);
if (layer->time_mix_receptance_b) {
r = ggml_add(ctx, r, layer->time_mix_receptance_b);
}
if (layer->time_mix_key_b) {
k = ggml_add(ctx, k, layer->time_mix_key_b);
}
if (layer->time_mix_value_b) {
v = ggml_add(ctx, v, layer->time_mix_value_b);
}
r = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, head_count, n_tokens);
k = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), head_size, head_count_kv, n_tokens);
v = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, head_count_kv, n_tokens);
struct ggml_tensor * g = ggml_silu( struct ggml_tensor * g = ggml_silu(
ctx, ctx,
llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg) llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
); );
if (head_count_kv != head_count) {
GGML_ASSERT(head_count % head_count_kv == 0);
k = ggml_repeat(ctx, k, r);
v = ggml_repeat(ctx, v, r);
}
struct ggml_tensor * w = ggml_mul_mat( struct ggml_tensor * w = ggml_mul_mat(
ctx, ctx,
layer->time_mix_decay_w2, layer->time_mix_decay_w2,
@ -3353,21 +3433,36 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w))); w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens); w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
k = ggml_transpose(ctx, k); r = ggml_reshape_4d(ctx, r, 1, head_size, head_count, n_tokens);
v = ggml_transpose(ctx, v); k = ggml_reshape_4d(ctx, k, 1, head_size, head_count, n_tokens);
r = ggml_transpose(ctx, r); v = ggml_reshape_4d(ctx, v, 1, head_size, head_count, n_tokens);
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); if (is_qrwkv) {
// k = k * (1 - w)
k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
}
k = ggml_transpose(ctx, k);
struct ggml_tensor * wkv_output;
if (!layer->time_mix_first) {
// TODO: GatedLinearAttention
} else {
wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
}
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
// group norm with head_count groups if (!skip_groupnorm) {
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens); // group norm with head_count groups
cur = ggml_norm(ctx, cur, 64e-5f); cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
cur = ggml_norm(ctx, cur, 64e-5f);
// Convert back to regular vectors. // Convert back to regular vectors.
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens); cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
} else {
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
}
cur = ggml_mul(ctx, cur, g); cur = ggml_mul(ctx, cur, g);
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur); cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
@ -9785,7 +9880,7 @@ struct llm_build_context {
1 1
); );
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states)); cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1], false, false));
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand( ggml_build_forward_expand(
gf, gf,
@ -9852,6 +9947,103 @@ struct llm_build_context {
return gf; return gf;
} }
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
// Build the compute graph for RWKV6Qwen2: RWKV-6 time-mix layers with
// grouped KV heads (hparams.n_head_kv()), a gated SiLU FFN (up/gate/down),
// and RMS norms throughout. Recurrent state lives in the KV cache:
// k_l holds the token-shift state, v_l holds the wkv state.
ggml_cgraph * build_rwkv6qwen2() {
    ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

    // token_shift_count must be 1 for this arch: one n_embd-sized shift state per layer
    GGML_ASSERT(n_embd == hparams.n_embd_k_s());

    const int64_t n_seqs = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
    const int64_t n_tokens = ubatch.n_tokens;
    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs);
    GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);

    struct ggml_tensor * cur;
    struct ggml_tensor * inpL;
    struct ggml_tensor * state_copy = build_inp_s_copy();
    struct ggml_tensor * state_mask = build_inp_s_mask();

    inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

    for (int il = 0; il < n_layer; ++il) {
        const llama_layer * layer = &model.layers[il];

        // (ab)using the KV cache to store the states
        struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
                gf, kv_self.k_l[il], state_copy, state_mask,
                hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
        struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
                gf, kv_self.v_l[il], state_copy, state_mask,
                hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

        cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
        token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);

        struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
        // x_prev = [token_shift, x_norm_att[:-1]]: each token sees the previous
        // token's normed activation; the first token sees the saved shift state
        struct ggml_tensor * x_prev = ggml_concat(
            ctx0,
            token_shift,
            ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
            1
        );

        // NOTE(review): unlike wkv_states below, token_shift is read but never
        // written back to kv_self.k_l here — confirm the shift-state store is
        // intentional to omit (this is a WIP commit).
        // NOTE(review): this write-back is expanded *before* the time-mix call
        // that updates wkv_states — verify it stores the post-update state and
        // not the state loaded above.
        ggml_build_forward_expand(
            gf,
            ggml_cpy(
                ctx0,
                wkv_states,
                ggml_view_1d(
                    ctx0,
                    kv_self.v_l[il],
                    hparams.n_embd_v_s() * n_seqs,
                    hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
                )
            )
        );

        // time-mix output is the residual branch; key_decay=true for QRWKV.
        // NOTE(review): hparams.wkv_skip_groupnorm is not declared in the
        // llama_hparams change shown in this commit — confirm it exists.
        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv(), true, hparams.wkv_skip_groupnorm));
        cb(ffn_inp, "ffn_inp", il);

        // feed-forward network
        cur = llm_build_norm(ctx0, ffn_inp, hparams,
                model.layers[il].ffn_norm, NULL,
                LLM_NORM_RMS, cb, il);
        cb(cur, "ffn_norm", il);

        // Qwen2-style gated FFN: SiLU, parallel up/gate
        cur = llm_build_ffn(ctx0, lctx, cur,
                model.layers[il].ffn_up, NULL, NULL,
                model.layers[il].ffn_gate, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
        cb(cur, "ffn_out", il);

        cur = ggml_add(ctx0, cur, ffn_inp);
        cur = lctx.cvec.apply_to(ctx0, cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;
    // keep only the rows we need logits for
    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    cur = ggml_get_rows(ctx0, cur, inp_out_ids);

    cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
    cb(cur, "result_norm", -1);

    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
    cb(cur, "result_output", -1);

    ggml_build_forward_expand(gf, cur);

    return gf;
}
// ref: https://github.com/facebookresearch/chameleon // ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes: // based on the original build_llama() function, changes:
// * qk-norm // * qk-norm
@ -10456,6 +10648,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_rwkv6(); result = llm.build_rwkv6();
} break; } break;
case LLM_ARCH_RWKV6QWEN2:
{
result = llm.build_rwkv6qwen2();
} break;
case LLM_ARCH_CHAMELEON: case LLM_ARCH_CHAMELEON:
{ {
result = llm.build_chameleon(); result = llm.build_chameleon();