RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-12-28 17:30:53 +08:00
parent fab0aa7b1a
commit bc930cd59a
8 changed files with 75 additions and 98 deletions

View file

@ -326,6 +326,7 @@ class Model:
gguf.MODEL_TENSOR.TIME_MIX_W2,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
)
@ -3256,6 +3257,7 @@ class Rwkv6Model(Model):
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
lerp_weights: dict[int, dict[str, Tensor]] = {}
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)
@ -3271,16 +3273,32 @@ class Rwkv6Model(Model):
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
data_torch = data_torch.squeeze()
rescale_every_n_layers = self.hparams["rescale_every"]
if rescale_every_n_layers > 0:
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
try:
rescale_every_n_layers = self.hparams["rescale_every"]
if rescale_every_n_layers > 0:
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
except KeyError:
pass
# concat time_mix_lerp weights to reduce some cpu overhead
# also reduces the number of tensors in the model
if bid is not None and "time_mix_lerp" in new_name and not "time_mix_lerp_x" in new_name:
try:
self.lerp_weights[bid][new_name] = data_torch
except KeyError:
self.lerp_weights[bid] = {new_name: data_torch}
if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
yield (new_name, data)
return
yield (new_name, data_torch)
@Model.register("RWKV6Qwen2ForCausalLM")
class RWKV6Qwen2Model(Model):
class RWKV6Qwen2Model(Rwkv6Model):
model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
def set_vocab(self):
@ -3320,21 +3338,17 @@ class RWKV6Qwen2Model(Model):
self.gguf_writer.add_head_count(0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"
if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
data_torch = data_torch.transpose(0, 1)
if new_name.endswith("time_mix_w2.weight"):
data_torch = data_torch.permute(0, 2, 1)
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
data_torch = data_torch.squeeze()
yield (new_name, data_torch)
for new_name, data in super().modify_tensors(data_torch, name, bid):
if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
data = data.view(5, -1, data.shape[-1])
# rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg
# permute them here to avoid code changes
data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
if "w2" in new_name:
data = data.view(5, -1, data.shape[-1])
yield (new_name, data)
continue
yield (new_name, data)
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")

View file

@ -331,6 +331,7 @@ class MODEL_TENSOR(IntEnum):
TIME_MIX_LERP_V = auto()
TIME_MIX_LERP_R = auto()
TIME_MIX_LERP_G = auto()
TIME_MIX_LERP_FUSED = auto()
TIME_MIX_LERP_W = auto()
TIME_MIX_FIRST = auto()
TIME_MIX_DECAY = auto()
@ -514,6 +515,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
@ -1080,6 +1082,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
@ -1109,6 +1112,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,

View file

@ -1154,11 +1154,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
{ LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
{ LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
@ -1356,6 +1352,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View file

@ -250,6 +250,7 @@ enum llm_tensor {
LLM_TENSOR_TIME_MIX_LERP_V,
LLM_TENSOR_TIME_MIX_LERP_R,
LLM_TENSOR_TIME_MIX_LERP_G,
LLM_TENSOR_TIME_MIX_LERP_FUSED,
LLM_TENSOR_TIME_MIX_FIRST,
LLM_TENSOR_TIME_MIX_DECAY,
LLM_TENSOR_TIME_MIX_DECAY_W1,

View file

@ -2164,6 +2164,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV6: return true;
case LLM_ARCH_RWKV6QWEN2: return true;
default: return false;
}
}

View file

@ -238,6 +238,7 @@ struct llama_layer {
struct ggml_tensor * time_mix_lerp_v = nullptr;
struct ggml_tensor * time_mix_lerp_r = nullptr;
struct ggml_tensor * time_mix_lerp_g = nullptr;
struct ggml_tensor * time_mix_lerp_fused = nullptr;
struct ggml_tensor * time_mix_first = nullptr;
struct ggml_tensor * time_mix_decay = nullptr;

View file

@ -760,6 +760,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;

View file

@ -2123,11 +2123,13 @@ static bool llm_load_tensors(
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@ -2180,11 +2182,7 @@ static bool llm_load_tensors(
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@ -3337,72 +3335,32 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
xxx
);
struct ggml_tensor *mw, *mk, *mv, *mr, *mg;
if (is_qrwkv) {
// Why the f*** do they change the order here?
mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
if (layer->time_mix_lerp_fused) {
// fusing these weights makes some performance improvement
sx = ggml_reshape_3d(ctx, sx, n_embd, 1, n_tokens);
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
} else {
mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
// for backward compatibility
xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
}
struct ggml_tensor * xw = ggml_add(
ctx,
ggml_mul(
ctx,
ggml_add(ctx, mw, layer->time_mix_lerp_w),
sx
),
cur
);
struct ggml_tensor * xk = ggml_add(
ctx,
ggml_mul(
ctx,
ggml_add(ctx, mk, layer->time_mix_lerp_k),
sx
),
cur
);
struct ggml_tensor * xv = ggml_add(
ctx,
ggml_mul(
ctx,
ggml_add(ctx, mv, layer->time_mix_lerp_v),
sx
),
cur
);
struct ggml_tensor * xr = ggml_add(
ctx,
ggml_mul(
ctx,
ggml_add(ctx, mr, layer->time_mix_lerp_r),
sx
),
cur
);
struct ggml_tensor * xg = ggml_add(
ctx,
ggml_mul(
ctx,
ggml_add(ctx, mg, layer->time_mix_lerp_g),
sx
),
cur
);
struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);