llama: rwkv6: Add lora for some supported tensors

Currently supported: att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2024-08-30 12:11:31 +08:00
parent 7444046c47
commit 7f2ef56639

@@ -9384,6 +9384,7 @@ static struct ggml_tensor * llm_build_mamba(
 }

 static struct ggml_tensor * llm_build_time_mix_rwkv6(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
@@ -9481,12 +9482,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
         cur
     );

-    struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
     struct ggml_tensor * g = ggml_silu(
         ctx,
-        ggml_mul_mat(ctx, layer->time_mix_gate, xg)
+        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
     );

     struct ggml_tensor * w = ggml_mul_mat(
@@ -9516,12 +9517,13 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);

     cur = ggml_mul(ctx, cur, g);
-    cur = ggml_mul_mat(ctx, layer->time_mix_output, cur);
+    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);

     return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
 }

 static struct ggml_tensor * llm_build_channel_mix_rwkv6(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
@@ -9530,15 +9532,15 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6(
     struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
     struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);

-    struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
     struct ggml_tensor * k = ggml_sqr(
         ctx,
         ggml_relu(
             ctx,
-            ggml_mul_mat(ctx, layer->channel_mix_key, xk)
+            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
         )
     );
-    return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
+    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
 }

 struct llm_build_context {
@@ -15109,7 +15111,7 @@ struct llm_build_context {
                 1
             );

-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -15132,7 +15134,7 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev));
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);

             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -15166,7 +15168,7 @@ struct llm_build_context {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);

            cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-           cur = ggml_mul_mat(ctx0, model.output, cur);
+           cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
            cb(cur, "result_output", -1);

            ggml_build_forward_expand(gf, cur);
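
For reference, llm_build_lora_mm is the existing helper from the LoRA hot-swap support that these call sites switch to: it performs the base ggml_mul_mat and then adds each loaded adapter's scaled low-rank delta. A minimal sketch of that behavior follows; the adapter bookkeeping shown (lctx.lora_adapters as an adapter-to-scale map, get_weight looking up the delta for a base tensor) reflects the helper's general shape and is illustrative, not the verbatim implementation.

// Sketch (not verbatim llama.cpp code): y = W*x, plus scale * B*(A*x)
// for every loaded adapter that carries a delta for the weight w.
static struct ggml_tensor * llm_build_lora_mm(
        struct llama_context & lctx,
        struct ggml_context * ctx0,
        struct ggml_tensor * w,
        struct ggml_tensor * cur) {
    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
    for (auto & it : lctx.lora_adapters) { // assumed map: adapter -> user scale
        struct llama_lora_weight * lora = it.first->get_weight(w);
        if (lora == nullptr) {
            continue; // this adapter has no delta for tensor w
        }
        // low-rank update B*(A*x), scaled by the per-adapter scale
        struct ggml_tensor * ab = ggml_mul_mat(ctx0, lora->b,
                ggml_mul_mat(ctx0, lora->a, cur));
        res = ggml_add(ctx0, res, ggml_scale(ctx0, ab, it.second));
    }
    return res;
}

Note that only the call sites rewritten above gain adapter support: the decay path in llm_build_time_mix_rwkv6 (the struct ggml_tensor * w = ggml_mul_mat(...) chain) stays on plain ggml_mul_mat, matching the tensor list in the commit message.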