rwkv: better handling for models without gate
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
9cad1ca194
commit
39eb446ad6
1 changed files with 5 additions and 3 deletions
|
@ -1057,8 +1057,10 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
|
|||
|
||||
size_t n_tokens = n_seqs * n_seq_tokens;
|
||||
|
||||
bool has_gating = layer->time_mix_g1 && layer->time_mix_g2;
|
||||
|
||||
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
|
||||
struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, 6);
|
||||
struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, layer->time_mix_lerp_fused->ne[2]);
|
||||
sx = ggml_repeat(ctx, sx, dummy);
|
||||
|
||||
struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_fused), cur);
|
||||
|
@ -1068,7 +1070,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
|
|||
struct ggml_tensor * xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
struct ggml_tensor * xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
struct ggml_tensor * xa = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
struct ggml_tensor * xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float));
|
||||
struct ggml_tensor * xg = has_gating ? ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
|
||||
|
||||
struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
|
||||
// Assume that there won't be lora adapters on these “lora” matmuls?
|
||||
|
@ -1142,7 +1144,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
|
|||
ggml_mul(ctx, ggml_mul(ctx, k, r), ggml_reshape_2d(ctx, layer->time_mix_r_k, head_size, head_count)));
|
||||
cur = ggml_add(ctx, cur, ggml_reshape_2d(ctx, ggml_mul(ctx, v, rk), n_embd, n_tokens));
|
||||
|
||||
if (g) {
|
||||
if (has_gating) {
|
||||
cur = ggml_mul(ctx, cur, g);
|
||||
}
|
||||
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue