WIP: Add support for rwkv v7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2025-01-15 20:43:23 +08:00
parent 5445300758
commit 6dcc21e7f5
14 changed files with 952 additions and 48 deletions


@@ -118,22 +118,26 @@ class Keys:
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_LENGTH = "{arch}.attention.key_length"
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
HEAD_COUNT = "{arch}.attention.head_count"
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_LENGTH = "{arch}.attention.key_length"
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -256,6 +260,7 @@ class MODEL_ARCH(IntEnum):
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
RWKV7 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
@@ -328,8 +333,20 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto()
SSM_D = auto()
SSM_OUT = auto()
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
TIME_MIX_A0 = auto()
TIME_MIX_A1 = auto()
TIME_MIX_A2 = auto()
TIME_MIX_V0 = auto()
TIME_MIX_V1 = auto()
TIME_MIX_V2 = auto()
TIME_MIX_G1 = auto()
TIME_MIX_G2 = auto()
TIME_MIX_K_K = auto()
TIME_MIX_K_A = auto()
TIME_MIX_R_K = auto()
TIME_MIX_LERP_X = auto()
TIME_MIX_LERP_K = auto()
TIME_MIX_LERP_V = auto()
@@ -443,6 +460,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
@@ -515,8 +533,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
@@ -1153,6 +1183,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.RWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
MODEL_TENSOR.CHANNEL_MIX_KEY,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
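
For reference, a minimal sketch (not part of the commit) of how the new RWKV7 constants above expand into GGUF key and tensor names; it assumes this branch's gguf-py is on the Python path, and the block index is illustrative.

# Minimal sketch, assuming this branch's gguf-py is importable.
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSOR, TENSOR_NAMES, Keys

arch = MODEL_ARCH_NAMES[MODEL_ARCH.RWKV7]                     # "rwkv7"

# Per-architecture KV keys are format strings parameterized by the arch name.
print(Keys.Attention.DECAY_LORA_RANK.format(arch=arch))       # rwkv7.attention.decay_lora_rank
print(Keys.Attention.GATE_LORA_RANK.format(arch=arch))        # rwkv7.attention.gate_lora_rank

# Per-block tensor names are format strings parameterized by the block index.
print(TENSOR_NAMES[MODEL_TENSOR.TIME_MIX_W0].format(bid=0))   # blk.0.time_mix_w0
print(TENSOR_NAMES[MODEL_TENSOR.TIME_MIX_K_K].format(bid=0))  # blk.0.time_mix_k_k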


@@ -767,6 +767,18 @@ class GGUFWriter:
def add_kv_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
def add_decay_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
def add_iclr_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
def add_value_residual_mix_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
def add_gate_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
def add_relative_attn_buckets_count(self, value: int) -> None:
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
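
A minimal usage sketch (not part of the commit) for the new writer helpers; the output path and rank values are made-up placeholders, and the rest of a real conversion (remaining metadata, tensor data, the write_*_to_file calls) is omitted.

# Minimal sketch, assuming this branch's gguf-py; values are placeholders.
from gguf import GGUFWriter

writer = GGUFWriter("rwkv7-model.gguf", arch="rwkv7")
writer.add_decay_lora_rank(64)                 # hypothetical ranks; a converter would read these from the source model
writer.add_iclr_lora_rank(64)
writer.add_value_residual_mix_lora_rank(32)
writer.add_gate_lora_rank(128)
# ... remaining metadata, tensors, and writer.write_*_to_file() calls omitted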


@@ -27,7 +27,8 @@ class TensorNameMap:
"embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm
"shared", # t5
"rwkv.embeddings", # rwkv
"rwkv.embeddings", # rwkv v6
"model.embeddings", # rwkv v7
),
# Token type embeddings
@@ -41,7 +42,8 @@ class TensorNameMap:
"embeddings.LayerNorm", # bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv
"rwkv.blocks.0.pre_ln", # rwkv v6
"model.pre_ln", # rwkv v7
"backbone.norm", # wavtokenizer
),
@@ -81,7 +83,8 @@ class TensorNameMap:
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"model.norm", # nemotron
"rwkv.ln_out", # rwkv
"rwkv.ln_out", # rwkv v6
"model.ln_out", # rwkv v7
"backbone.final_layer_norm", # wavtokenizer
),
@@ -122,14 +125,16 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv
"rwkv.blocks.{bid}.ln1", # rwkv v6
"model.blocks.{bid}.ln1", # rwkv v7
),
# Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv
"rwkv.blocks.{bid}.ln2", # rwkv v6
"model.blocks.{bid}.ln2", # rwkv v7
),
# Attention query-key-value
@@ -462,14 +467,64 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.out_proj",
),
MODEL_TENSOR.TIME_MIX_W0: (
"model.blocks.{bid}.attention.w0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
"model.blocks.{bid}.attention.w1" # rwkv7
),
MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
"model.blocks.{bid}.attention.w2" # rwkv7
),
MODEL_TENSOR.TIME_MIX_A0: (
"model.blocks.{bid}.attention.a0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_A1: (
"model.blocks.{bid}.attention.a1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_A2: (
"model.blocks.{bid}.attention.a2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V0: (
"model.blocks.{bid}.attention.v0", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V1: (
"model.blocks.{bid}.attention.v1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_V2: (
"model.blocks.{bid}.attention.v2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_G1: (
"model.blocks.{bid}.attention.g1", # rwkv7
),
MODEL_TENSOR.TIME_MIX_G2: (
"model.blocks.{bid}.attention.g2", # rwkv7
),
MODEL_TENSOR.TIME_MIX_K_K: (
"model.blocks.{bid}.attention.k_k", # rwkv7
),
MODEL_TENSOR.TIME_MIX_K_A: (
"model.blocks.{bid}.attention.k_a", # rwkv7
),
MODEL_TENSOR.TIME_MIX_R_K: (
"model.blocks.{bid}.attention.r_k", # rwkv7
),
MODEL_TENSOR.TIME_MIX_LERP_X: (
@@ -522,36 +577,42 @@ class TensorNameMap:
),
MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv
"rwkv.blocks.{bid}.attention.key", # rwkv v6
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.key", # rwkv v7
),
MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv
"rwkv.blocks.{bid}.attention.value", # rwkv v6
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.value", # rwkv v7
),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
"rwkv.blocks.{bid}.attention.receptance", # rwkv v6
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.receptance", # rwkv v7
),
MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv
"rwkv.blocks.{bid}.attention.gate", # rwkv v6
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LN: (
"rwkv.blocks.{bid}.attention.ln_x", # rwkv
"rwkv.blocks.{bid}.attention.ln_x", # rwkv v6
"model.blocks.{bid}.attention.ln_x" # rwkv v7
),
MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
"model.blocks.{bid}.attention.output", # rwkv v7
),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
"model.blocks.{bid}.feed_forward.x_k", # rwkv v7
),
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
@@ -559,7 +620,8 @@ class TensorNameMap:
),
MODEL_TENSOR.CHANNEL_MIX_KEY: (
"rwkv.blocks.{bid}.feed_forward.key", # rwkv
"rwkv.blocks.{bid}.feed_forward.key", # rwkv v6
"model.blocks.{bid}.feed_forward.key", # rwkv v7
),
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
@@ -567,7 +629,8 @@ class TensorNameMap:
),
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
"rwkv.blocks.{bid}.feed_forward.value", # rwkv
"rwkv.blocks.{bid}.feed_forward.value", # rwkv v6
"model.blocks.{bid}.feed_forward.value", # rwkv v7
),
MODEL_TENSOR.ATTN_Q_A: (
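
To close, a small sketch (again not part of the commit) of how the mappings added above resolve an rwkv v7 checkpoint tensor name to its GGUF name; the block count and tensor name are illustrative.

# Minimal sketch, assuming this branch's gguf-py; the block count is arbitrary.
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

name_map = get_tensor_name_map(MODEL_ARCH.RWKV7, n_blocks=24)

# "model.blocks.3.attention.w0.weight" (checkpoint) -> "blk.3.time_mix_w0.weight" (GGUF)
print(name_map.get_name("model.blocks.3.attention.w0.weight", try_suffixes=(".weight",)))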