Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-08-12 14:30:04 +08:00
parent 5afa3eff3a
commit 12fbe1ade2
3 changed files with 18 additions and 18 deletions

View file

@ -2719,7 +2719,7 @@ class StarCoder2Model(Model):
@Model.register("Rwkv6ForCausalLM")
class RwkvModel(Model):
model_arch = gguf.MODEL_ARCH.RWKV
model_arch = gguf.MODEL_ARCH.RWKV6
def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()

View file

@ -211,7 +211,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
RWKV = auto()
RWKV6 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
@ -365,7 +365,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV: "rwkv",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
@ -908,7 +908,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.RWKV: [
MODEL_ARCH.RWKV6: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,

View file

@ -212,7 +212,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV,
LLM_ARCH_RWKV6,
LLM_ARCH_UNKNOWN,
};
@ -260,7 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV, "rwkv" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -1371,7 +1371,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
LLM_ARCH_RWKV,
LLM_ARCH_RWKV6,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
@ -5903,7 +5903,7 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
@ -8338,7 +8338,7 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -9361,7 +9361,7 @@ static struct ggml_tensor * llm_build_mamba(
return cur;
}
static struct ggml_tensor * llm_build_time_mix(
static struct ggml_tensor * llm_build_time_mix_rwkv6(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * cur,
@ -9522,7 +9522,7 @@ static struct ggml_tensor * llm_build_time_mix(
return ggml_mul_mat(ctx, layer->time_mix_output, cur);
}
static struct ggml_tensor * llm_build_channel_mix(
static struct ggml_tensor * llm_build_channel_mix_rwkv6(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * cur,
@ -15064,7 +15064,7 @@ struct llm_build_context {
return gf;
}
ggml_cgraph * build_rwkv() {
ggml_cgraph * build_rwkv6() {
ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// Token shift state dimensions should be 2 * n_emb
@ -15112,7 +15112,7 @@ struct llm_build_context {
n_embd, n_tokens
);
cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
@ -15148,7 +15148,7 @@ struct llm_build_context {
ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
n_embd, n_tokens
);
cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
@ -15444,9 +15444,9 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_exaone();
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
result = llm.build_rwkv();
result = llm.build_rwkv6();
} break;
default:
GGML_ABORT("fatal error");
@ -18477,7 +18477,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS:
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
@ -18646,7 +18646,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
bool llama_model_is_recurrent(const struct llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV: return true;
case LLM_ARCH_RWKV6: return true;
default: return false;
}
}