WIP: Add support for rwkv v7
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
commit 6dcc21e7f5
parent 5445300758
14 changed files with 952 additions and 48 deletions
gguf-py/gguf/constants.py

@@ -118,22 +118,26 @@ class Keys:
         TOKEN_SHIFT_COUNT            = "{arch}.token_shift_count"
 
     class Attention:
         HEAD_COUNT                   = "{arch}.attention.head_count"
         HEAD_COUNT_KV                = "{arch}.attention.head_count_kv"
         MAX_ALIBI_BIAS               = "{arch}.attention.max_alibi_bias"
         CLAMP_KQV                    = "{arch}.attention.clamp_kqv"
         KEY_LENGTH                   = "{arch}.attention.key_length"
         VALUE_LENGTH                 = "{arch}.attention.value_length"
         LAYERNORM_EPS                = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS            = "{arch}.attention.layer_norm_rms_epsilon"
         GROUPNORM_EPS                = "{arch}.attention.group_norm_epsilon"
         GROUPNORM_GROUPS             = "{arch}.attention.group_norm_groups"
         CAUSAL                       = "{arch}.attention.causal"
         Q_LORA_RANK                  = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK                 = "{arch}.attention.kv_lora_rank"
+        DECAY_LORA_RANK              = "{arch}.attention.decay_lora_rank"
+        ICLR_LORA_RANK               = "{arch}.attention.iclr_lora_rank"
+        VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
+        GATE_LORA_RANK               = "{arch}.attention.gate_lora_rank"
         REL_BUCKETS_COUNT            = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW               = "{arch}.attention.sliding_window"
         SCALE                        = "{arch}.attention.scale"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
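The four new keys follow the existing "{arch}.attention.*" template; despite the _lora_rank suffix they size RWKV v7's built-in low-rank projections (decay, in-context learning rate, value-residual mix, gate) rather than LoRA adapters. A minimal sketch, assuming the gguf-py package from this tree is importable, of how the keys expand:

# Sketch only (not part of the commit): expand the new keys for rwkv7.
from gguf import Keys

for key in (
    Keys.Attention.DECAY_LORA_RANK,
    Keys.Attention.ICLR_LORA_RANK,
    Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK,
    Keys.Attention.GATE_LORA_RANK,
):
    print(key.format(arch="rwkv7"))
# rwkv7.attention.decay_lora_rank
# rwkv7.attention.iclr_lora_rank
# rwkv7.attention.value_residual_mix_lora_rank
# rwkv7.attention.gate_lora_rank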
@@ -256,6 +260,7 @@ class MODEL_ARCH(IntEnum):
     STARCODER2   = auto()
     RWKV6        = auto()
     RWKV6QWEN2   = auto()
+    RWKV7        = auto()
     MAMBA        = auto()
     XVERSE       = auto()
     COMMAND_R    = auto()
@@ -328,8 +333,20 @@ class MODEL_TENSOR(IntEnum):
     SSM_A            = auto()
     SSM_D            = auto()
     SSM_OUT          = auto()
+    TIME_MIX_W0      = auto()
     TIME_MIX_W1      = auto()
     TIME_MIX_W2      = auto()
+    TIME_MIX_A0      = auto()
+    TIME_MIX_A1      = auto()
+    TIME_MIX_A2      = auto()
+    TIME_MIX_V0      = auto()
+    TIME_MIX_V1      = auto()
+    TIME_MIX_V2      = auto()
+    TIME_MIX_G1      = auto()
+    TIME_MIX_G2      = auto()
+    TIME_MIX_K_K     = auto()
+    TIME_MIX_K_A     = auto()
+    TIME_MIX_R_K     = auto()
     TIME_MIX_LERP_X  = auto()
     TIME_MIX_LERP_K  = auto()
     TIME_MIX_LERP_V  = auto()
@@ -443,6 +460,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.STARCODER2:     "starcoder2",
     MODEL_ARCH.RWKV6:          "rwkv6",
     MODEL_ARCH.RWKV6QWEN2:     "rwkv6qwen2",
+    MODEL_ARCH.RWKV7:          "rwkv7",
     MODEL_ARCH.MAMBA:          "mamba",
     MODEL_ARCH.XVERSE:         "xverse",
     MODEL_ARCH.COMMAND_R:      "command-r",
@@ -515,8 +533,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_A:           "blk.{bid}.ssm_a",
     MODEL_TENSOR.SSM_D:           "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_OUT:         "blk.{bid}.ssm_out",
+    MODEL_TENSOR.TIME_MIX_W0:     "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1:     "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2:     "blk.{bid}.time_mix_w2",
+    MODEL_TENSOR.TIME_MIX_A0:     "blk.{bid}.time_mix_a0",
+    MODEL_TENSOR.TIME_MIX_A1:     "blk.{bid}.time_mix_a1",
+    MODEL_TENSOR.TIME_MIX_A2:     "blk.{bid}.time_mix_a2",
+    MODEL_TENSOR.TIME_MIX_V0:     "blk.{bid}.time_mix_v0",
+    MODEL_TENSOR.TIME_MIX_V1:     "blk.{bid}.time_mix_v1",
+    MODEL_TENSOR.TIME_MIX_V2:     "blk.{bid}.time_mix_v2",
+    MODEL_TENSOR.TIME_MIX_G1:     "blk.{bid}.time_mix_g1",
+    MODEL_TENSOR.TIME_MIX_G2:     "blk.{bid}.time_mix_g2",
+    MODEL_TENSOR.TIME_MIX_K_K:    "blk.{bid}.time_mix_k_k",
+    MODEL_TENSOR.TIME_MIX_K_A:    "blk.{bid}.time_mix_k_a",
+    MODEL_TENSOR.TIME_MIX_R_K:    "blk.{bid}.time_mix_r_k",
     MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
     MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
     MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
@@ -1153,6 +1183,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.RWKV7: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_W0,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_A0,
+        MODEL_TENSOR.TIME_MIX_A1,
+        MODEL_TENSOR.TIME_MIX_A2,
+        MODEL_TENSOR.TIME_MIX_V0,
+        MODEL_TENSOR.TIME_MIX_V1,
+        MODEL_TENSOR.TIME_MIX_V2,
+        MODEL_TENSOR.TIME_MIX_G1,
+        MODEL_TENSOR.TIME_MIX_G2,
+        MODEL_TENSOR.TIME_MIX_K_K,
+        MODEL_TENSOR.TIME_MIX_K_A,
+        MODEL_TENSOR.TIME_MIX_R_K,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+        MODEL_TENSOR.CHANNEL_MIX_KEY,
+        MODEL_TENSOR.CHANNEL_MIX_VALUE,
+    ],
     MODEL_ARCH.MAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
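The RWKV7 tensor list above, combined with TENSOR_NAMES, fixes the GGUF-side name of every tensor the converter may emit. A small sketch, under the same import assumption as above, that enumerates them:

# Sketch only: list the GGUF names rwkv7 tensors map to (block 0 shown).
from gguf import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.RWKV7]:
    # names without a {bid} placeholder (token_embd, output, ...) pass through unchanged
    print(TENSOR_NAMES[tensor].format(bid=0))
# token_embd, token_embd_norm, ..., blk.0.time_mix_w0, ..., blk.0.channel_mix_value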
gguf-py/gguf/gguf_writer.py

@@ -767,6 +767,18 @@ class GGUFWriter:
     def add_kv_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 
+    def add_decay_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
+
+    def add_iclr_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
+
+    def add_value_residual_mix_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
+
+    def add_gate_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
+
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
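A hypothetical fragment of a convert script using the new setters; the rank values below are placeholders, not taken from any published RWKV v7 checkpoint:

# Sketch only: record the rwkv7 low-rank sizes in a GGUF file's metadata.
from gguf import GGUFWriter

writer = GGUFWriter("model.gguf", arch="rwkv7")
writer.add_decay_lora_rank(64)               # rank of the w1/w2 decay projection
writer.add_iclr_lora_rank(64)                # rank of the a1/a2 in-context-learning-rate projection
writer.add_value_residual_mix_lora_rank(32)  # rank of the v1/v2 value-residual projection
writer.add_gate_lora_rank(128)               # rank of the g1/g2 gate projection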
gguf-py/gguf/tensor_mapping.py

@@ -27,7 +27,8 @@ class TensorNameMap:
             "embedding.word_embeddings",    # chatglm
             "transformer.token_embeddings", # openelm
             "shared",                       # t5
-            "rwkv.embeddings",              # rwkv
+            "rwkv.embeddings",              # rwkv v6
+            "model.embeddings",             # rwkv v7
         ),
 
         # Token type embeddings
@@ -41,7 +42,8 @@ class TensorNameMap:
             "embeddings.LayerNorm",  # bert
             "emb_ln",                # nomic-bert
             "transformer.norm",      # openelm
-            "rwkv.blocks.0.pre_ln",  # rwkv
+            "rwkv.blocks.0.pre_ln",  # rwkv v6
+            "model.pre_ln",          # rwkv v7
             "backbone.norm",         # wavtokenizer
         ),
@@ -81,7 +83,8 @@ class TensorNameMap:
             "encoder.final_layernorm",   # chatglm
             "transformer.norm",          # openelm
             "model.norm",                # nemotron
-            "rwkv.ln_out",               # rwkv
+            "rwkv.ln_out",               # rwkv v6
+            "model.ln_out",              # rwkv v7
             "backbone.final_layer_norm", # wavtokenizer
         ),
@@ -122,14 +125,16 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
             "encoder.layers.{bid}.input_layernorm",           # chatglm
             "transformer.layers.{bid}.attn_norm",             # openelm
-            "rwkv.blocks.{bid}.ln1",                          # rwkv
+            "rwkv.blocks.{bid}.ln1",                          # rwkv v6
+            "model.blocks.{bid}.ln1",                         # rwkv v7
         ),
 
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
             "transformer.h.{bid}.ln_attn",      # falcon40b
             "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
-            "rwkv.blocks.{bid}.ln2",            # rwkv
+            "rwkv.blocks.{bid}.ln2",            # rwkv v6
+            "model.blocks.{bid}.ln2",           # rwkv v7
         ),
 
         # Attention query-key-value
@@ -462,14 +467,64 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.out_proj",
         ),
 
+        MODEL_TENSOR.TIME_MIX_W0: (
+            "model.blocks.{bid}.attention.w0", # rwkv7
+        ),
+
         MODEL_TENSOR.TIME_MIX_W1: (
             "rwkv.blocks.{bid}.attention.time_maa_w1",  # rwkv v6
             "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
+            "model.blocks.{bid}.attention.w1",          # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_W2: (
             "rwkv.blocks.{bid}.attention.time_maa_w2",  # rwkv v6
             "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
+            "model.blocks.{bid}.attention.w2",          # rwkv7
         ),
 
+        MODEL_TENSOR.TIME_MIX_A0: (
+            "model.blocks.{bid}.attention.a0", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_A1: (
+            "model.blocks.{bid}.attention.a1", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_A2: (
+            "model.blocks.{bid}.attention.a2", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V0: (
+            "model.blocks.{bid}.attention.v0", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V1: (
+            "model.blocks.{bid}.attention.v1", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V2: (
+            "model.blocks.{bid}.attention.v2", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G1: (
+            "model.blocks.{bid}.attention.g1", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G2: (
+            "model.blocks.{bid}.attention.g2", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_K: (
+            "model.blocks.{bid}.attention.k_k", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_A: (
+            "model.blocks.{bid}.attention.k_a", # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_R_K: (
+            "model.blocks.{bid}.attention.r_k", # rwkv7
+        ),
+
         MODEL_TENSOR.TIME_MIX_LERP_X: (
@@ -522,36 +577,42 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.TIME_MIX_KEY: (
-            "rwkv.blocks.{bid}.attention.key",     # rwkv
+            "rwkv.blocks.{bid}.attention.key",     # rwkv v6
             "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
+            "model.blocks.{bid}.attention.key",    # rwkv v7
         ),
 
         MODEL_TENSOR.TIME_MIX_VALUE: (
-            "rwkv.blocks.{bid}.attention.value",   # rwkv
+            "rwkv.blocks.{bid}.attention.value",   # rwkv v6
             "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
+            "model.blocks.{bid}.attention.value",  # rwkv v7
         ),
 
         MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
-            "rwkv.blocks.{bid}.attention.receptance",  # rwkv
+            "rwkv.blocks.{bid}.attention.receptance",  # rwkv v6
             "model.layers.{bid}.self_attn.q_proj",     # rwkv6qwen2
+            "model.blocks.{bid}.attention.receptance", # rwkv v7
         ),
 
         MODEL_TENSOR.TIME_MIX_GATE: (
-            "rwkv.blocks.{bid}.attention.gate",  # rwkv
+            "rwkv.blocks.{bid}.attention.gate",  # rwkv v6
             "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LN: (
-            "rwkv.blocks.{bid}.attention.ln_x",  # rwkv
+            "rwkv.blocks.{bid}.attention.ln_x",  # rwkv v6
+            "model.blocks.{bid}.attention.ln_x", # rwkv v7
         ),
 
         MODEL_TENSOR.TIME_MIX_OUTPUT: (
             "rwkv.blocks.{bid}.attention.output",  # rwkv
             "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
+            "model.blocks.{bid}.attention.output", # rwkv v7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
             "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
+            "model.blocks.{bid}.feed_forward.x_k",       # rwkv v7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
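Note the naming split the mappings above encode: RWKV v6 checkpoints prefix tensors with rwkv.blocks.*, while the RWKV v7 reference implementation uses model.blocks.*. Both should resolve to the same GGUF base name for tensors the two versions share; a sketch, assuming gguf-py from this tree:

# Sketch only: v6 and v7 checkpoint names for a shared tensor resolve
# to one GGUF name through their respective arch maps.
from gguf import MODEL_ARCH, get_tensor_name_map

v6 = get_tensor_name_map(MODEL_ARCH.RWKV6, n_blocks=1)
v7 = get_tensor_name_map(MODEL_ARCH.RWKV7, n_blocks=1)
print(v6.get_name("rwkv.blocks.0.attention.key.weight", try_suffixes=(".weight",)))
print(v7.get_name("model.blocks.0.attention.key.weight", try_suffixes=(".weight",)))
# both print: blk.0.time_mix_key.weight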
@@ -559,7 +620,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_KEY: (
-            "rwkv.blocks.{bid}.feed_forward.key",  # rwkv
+            "rwkv.blocks.{bid}.feed_forward.key",  # rwkv v6
+            "model.blocks.{bid}.feed_forward.key", # rwkv v7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
@@ -567,7 +629,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_VALUE: (
-            "rwkv.blocks.{bid}.feed_forward.value",  # rwkv
+            "rwkv.blocks.{bid}.feed_forward.value",  # rwkv v6
+            "model.blocks.{bid}.feed_forward.value", # rwkv v7
         ),
 
         MODEL_TENSOR.ATTN_Q_A: (
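Many of the new rwkv v7 entries (w0, a0, k_k, r_k, x_k, ...) appear to be bare parameters rather than Linear weights, so their checkpoint names carry no .weight suffix; TensorNameMap matches such keys exactly before trying suffixes. A sketch under the same assumption:

# Sketch only: bare-parameter tensors match without a suffix.
from gguf import MODEL_ARCH, get_tensor_name_map

name_map = get_tensor_name_map(MODEL_ARCH.RWKV7, n_blocks=4)
print(name_map.get_name("model.blocks.3.attention.w0", try_suffixes=(".weight", ".bias")))
# blk.3.time_mix_w0
print(name_map.get_name("model.blocks.3.feed_forward.x_k", try_suffixes=(".weight", ".bias")))
# blk.3.channel_mix_lerp_k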