convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-07-31 16:05:23 +08:00
parent 20f1789dfb
commit 8d2eca3507
3 changed files with 332 additions and 82 deletions

View file

@ -2716,6 +2716,79 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2 model_arch = gguf.MODEL_ARCH.STARCODER2
@Model.register("Rwkv6ForCausalLM")
class RwkvModel(Model):
model_arch = gguf.MODEL_ARCH.RWKV
def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)
tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
x = eval(line[line.index(' '):line.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(line[line.rindex(' '):])
token_text: str = ""
for b in x:
token_text += f"\\x{b:02x}"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(remainder):
tokens.append(f"<unused {i}>".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)
self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_size = self.hparams["head_size"]
hidden_size = self.hparams["hidden_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(0)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(0) # required by llama.cpp
# temporarlily reuse mamba hparams
self.gguf_writer.add_ssm_inner_size(hidden_size)
self.gguf_writer.add_ssm_conv_kernel(3)
self.gguf_writer.add_ssm_state_size(head_size)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
new_name = self.map_tensor_name(name)
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"
if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
data_torch = data_torch.transpose(0, 1)
if new_name.endswith("time_mix_w2.weight"):
data_torch = data_torch.permute(0, 2, 1)
rescale_every_n_layers = self.hparams["rescale_every"]
if rescale_every_n_layers > 0:
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
yield (new_name, data_torch)
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model): class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA model_arch = gguf.MODEL_ARCH.MAMBA

View file

@ -207,6 +207,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto() GEMMA = auto()
GEMMA2 = auto() GEMMA2 = auto()
STARCODER2 = auto() STARCODER2 = auto()
RWKV = auto()
MAMBA = auto() MAMBA = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
@ -270,6 +271,29 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto() SSM_A = auto()
SSM_D = auto() SSM_D = auto()
SSM_OUT = auto() SSM_OUT = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
TIME_MIX_LERP_X = auto()
TIME_MIX_LERP_K = auto()
TIME_MIX_LERP_V = auto()
TIME_MIX_LERP_R = auto()
TIME_MIX_LERP_G = auto()
TIME_MIX_LERP_W = auto()
TIME_MIX_FIRST = auto()
TIME_MIX_DECAY = auto()
TIME_MIX_DECAY_W1 = auto()
TIME_MIX_DECAY_W2 = auto()
TIME_MIX_KEY = auto()
TIME_MIX_VALUE = auto()
TIME_MIX_RECEPTANCE = auto()
TIME_MIX_GATE = auto()
TIME_MIX_LN = auto()
TIME_MIX_OUTPUT = auto()
CHANNEL_MIX_LERP_K = auto()
CHANNEL_MIX_LERP_R = auto()
CHANNEL_MIX_KEY = auto()
CHANNEL_MIX_RECEPTANCE = auto()
CHANNEL_MIX_VALUE = auto()
ATTN_Q_A = auto() ATTN_Q_A = auto()
ATTN_Q_B = auto() ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto() ATTN_KV_A_MQA = auto()
@ -337,6 +361,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV: "rwkv",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
@ -400,6 +425,29 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1",
MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2",
MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key",
MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value",
MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance",
MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate",
MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln",
MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output",
MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k",
MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r",
MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key",
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance",
MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value",
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
@ -856,6 +904,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.RWKV: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_LERP_X,
MODEL_TENSOR.TIME_MIX_LERP_K,
MODEL_TENSOR.TIME_MIX_LERP_V,
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
MODEL_TENSOR.TIME_MIX_DECAY_W2,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_GATE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
MODEL_TENSOR.CHANNEL_MIX_LERP_R,
MODEL_TENSOR.CHANNEL_MIX_KEY,
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.MAMBA: [ MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -27,6 +27,7 @@ class TensorNameMap:
"embedding.word_embeddings", # chatglm "embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm "transformer.token_embeddings", # openelm
"shared", # t5 "shared", # t5
"rwkv.embeddings", # rwkv
), ),
# Token type embeddings # Token type embeddings
@ -40,6 +41,7 @@ class TensorNameMap:
"embeddings.LayerNorm", # bert "embeddings.LayerNorm", # bert
"emb_ln", # nomic-bert "emb_ln", # nomic-bert
"transformer.norm", # openelm "transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv
), ),
# Position embeddings # Position embeddings
@ -57,6 +59,7 @@ class TensorNameMap:
"word_embeddings_for_head", # persimmon "word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2 "lm_head.linear", # phi2
"output_layer", # chatglm "output_layer", # chatglm
"head", # rwkv
), ),
# Output norm # Output norm
@ -76,6 +79,7 @@ class TensorNameMap:
"encoder.final_layernorm", # chatglm "encoder.final_layernorm", # chatglm
"transformer.norm", # openelm "transformer.norm", # openelm
"model.norm", # nemotron "model.norm", # nemotron
"rwkv.ln_out", # rwkv
), ),
# Rope frequencies # Rope frequencies
@ -108,12 +112,14 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm "encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm "transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv
), ),
# Attention norm 2 # Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: ( MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b "transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv
), ),
# Attention query-key-value # Attention query-key-value
@ -434,6 +440,98 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj",
), ),
MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_FIRST: (
"rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
),
MODEL_TENSOR.TIME_MIX_KEY: (
"rwkv.blocks.{bid}.attention.key", # rwkv
),
MODEL_TENSOR.TIME_MIX_VALUE: (
"rwkv.blocks.{bid}.attention.value", # rwkv
),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv
),
MODEL_TENSOR.TIME_MIX_GATE: (
"rwkv.blocks.{bid}.attention.gate", # rwkv
),
MODEL_TENSOR.TIME_MIX_LN: (
"rwkv.blocks.{bid}.attention.ln_x", # rwkv
),
MODEL_TENSOR.TIME_MIX_OUTPUT: (
"rwkv.blocks.{bid}.attention.output", # rwkv
),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
"rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
),
MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
"rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
),
MODEL_TENSOR.CHANNEL_MIX_KEY: (
"rwkv.blocks.{bid}.feed_forward.key", # rwkv
),
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
),
MODEL_TENSOR.CHANNEL_MIX_VALUE: (
"rwkv.blocks.{bid}.feed_forward.value", # rwkv
),
MODEL_TENSOR.ATTN_Q_A: ( MODEL_TENSOR.ATTN_Q_A: (
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2 "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
), ),