llama: add support for QRWKV6 model architecture (#11001)

* WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV: Some graph simplification

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Add support for RWKV6Qwen2 with CPU and CUDA GLA

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* RWKV6[QWEN2]: Concat lerp weights together to reduce CPU overhead

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix some typos

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix WKV test & add GLA test

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Fix CUDA warning

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update README.md

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Update ggml/src/ggml-cuda/gla.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Better sanity check skipping for QRWKV6 in llama-quant

Thanks @compilade

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: compilade <git@compilade.net>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: compilade <git@compilade.net>
Author: Molly Sophia <mollysophia379@gmail.com>
Date:   2025-01-10 09:58:08 +08:00 (committed by GitHub)
Commit: ee7136c6d1 (parent: c6860cc734)
23 changed files with 862 additions and 124 deletions
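The commit adds CPU and CUDA implementations of gated linear attention (GLA, used by the RWKV6Qwen2 time-mix block) together with a test for the new op. For orientation only, the per-head recurrence such a kernel computes can be sketched as below; this is a minimal NumPy reference under the standard GLA formulation (per-channel decay applied along the key dimension), with an assumed argument order and optional scale, not the actual ggml kernel.

```python
import numpy as np

def gla_ref(q, k, v, g, scale=1.0):
    # Naive single-head gated linear attention recurrence (reference sketch).
    # q, k, g: (T, d_k); v: (T, d_v); g holds per-channel decay factors in (0, 1).
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))                            # recurrent state carried across tokens
    out = np.empty((T, d_v))
    for t in range(T):
        S = g[t][:, None] * S + np.outer(k[t], v[t])    # decay the state, add k_t v_t^T
        out[t] = scale * (q[t] @ S)                     # read out with the query/receptance
    return out
```

A test like the one added in this commit would typically compare the fused kernel's output against a straightforward loop of this form.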

gguf-py/gguf/constants.py

@@ -115,6 +115,7 @@ class Keys:
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -255,6 +256,7 @@ class MODEL_ARCH(IntEnum):
GEMMA2 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
@@ -334,6 +336,7 @@ class MODEL_TENSOR(IntEnum):
TIME_MIX_LERP_V = auto()
TIME_MIX_LERP_R = auto()
TIME_MIX_LERP_G = auto()
TIME_MIX_LERP_FUSED = auto()
TIME_MIX_LERP_W = auto()
TIME_MIX_FIRST = auto()
TIME_MIX_DECAY = auto()
@@ -440,6 +443,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
@@ -519,6 +523,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
@@ -1103,6 +1108,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
@@ -1119,6 +1125,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.RWKV6QWEN2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_LERP_X,
MODEL_TENSOR.TIME_MIX_LERP_K,
MODEL_TENSOR.TIME_MIX_LERP_V,
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
MODEL_TENSOR.TIME_MIX_DECAY_W2,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_GATE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
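The new TIME_MIX_LERP_FUSED tensor backs the "concat lerp weights together" change from the commit message: the five per-channel token-shift lerp coefficients can be stored as one stacked tensor so the graph computes all lerps in a single op instead of five. A minimal sketch of such a fusion at conversion time is shown below; the helper name, the (w, k, v, r, g) stacking order, and the resulting shape are illustrative assumptions, not the authoritative conversion code.

```python
import torch

def fuse_time_maa(state_dict: dict, bid: int) -> torch.Tensor:
    # Hypothetical helper: stack one block's five RWKV6 lerp coefficients into
    # a single tensor destined for "blk.{bid}.time_mix_lerp_fused".
    prefix = f"rwkv.blocks.{bid}.attention."
    names = ("time_maa_w", "time_maa_k", "time_maa_v", "time_maa_r", "time_maa_g")
    parts = [state_dict[prefix + n].squeeze() for n in names]
    return torch.stack(parts, dim=0)  # assumed shape: (5, n_embd), order assumed
```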

gguf-py/gguf/gguf_writer.py

@@ -743,6 +743,9 @@ class GGUFWriter:
def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
def add_token_shift_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
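add_token_shift_count writes the new "{arch}.token_shift_count" key next to the existing WKV metadata. A hypothetical excerpt from a conversion script, assuming the usual GGUFWriter(path, arch) constructor and using placeholder values:

```python
import gguf

# Placeholder values for illustration; a real converter reads these from the model config.
writer = gguf.GGUFWriter("qrwkv6qwen2.gguf", arch="rwkv6qwen2")
writer.add_wkv_head_size(64)      # existing RWKV metadata key
writer.add_token_shift_count(1)   # new key introduced by this commit
```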

gguf-py/gguf/tensor_mapping.py

@@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe olmo2
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@@ -464,34 +464,42 @@ class TensorNameMap:
MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_FIRST: (
@@ -500,30 +508,37 @@ class TensorNameMap:
MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv
"rwkv.blocks.{bid}.attention.key", # rwkv
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv
"rwkv.blocks.{bid}.attention.value", # rwkv
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv
"rwkv.blocks.{bid}.attention.gate", # rwkv
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LN: (
@@ -531,7 +546,8 @@ class TensorNameMap:
),
MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv
"rwkv.blocks.{bid}.attention.output", # rwkv
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
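With these entries, the converter can resolve Hugging Face rwkv6qwen2 tensor names to their GGUF counterparts through the shared name map. A quick sanity check, assuming the gguf-py helpers get_tensor_name_map and TensorNameMap.get_name keep their usual signatures:

```python
import gguf

# 24 is an arbitrary block count; any value covering the referenced layer works here.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.RWKV6QWEN2, 24)

print(tmap.get_name("model.layers.0.self_attn.time_maa_w1"))  # expected: "blk.0.time_mix_w1"
print(tmap.get_name("model.layers.0.self_attn.q_proj"))       # expected: "blk.0.time_mix_receptance"
```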