diff --git a/src/llama.cpp b/src/llama.cpp index 9bf8d65f3..ef6c4632b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9384,6 +9384,7 @@ static struct ggml_tensor * llm_build_mamba( } static struct ggml_tensor * llm_build_time_mix_rwkv6( + struct llama_context & lctx, struct ggml_context * ctx, const struct llama_layer * layer, struct ggml_tensor * cur, @@ -9481,12 +9482,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( cur ); - struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens); - struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens); - struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens); + struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens); + struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens); + struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens); struct ggml_tensor * g = ggml_silu( ctx, - ggml_mul_mat(ctx, layer->time_mix_gate, xg) + llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg) ); struct ggml_tensor * w = ggml_mul_mat( @@ -9516,12 +9517,13 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); cur = ggml_mul(ctx, cur, g); - cur = ggml_mul_mat(ctx, layer->time_mix_output, cur); + cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur); return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs); } static struct ggml_tensor * llm_build_channel_mix_rwkv6( + struct llama_context & lctx, struct ggml_context * ctx, const struct llama_layer * layer, struct ggml_tensor * cur, @@ -9530,15 +9532,15 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6( struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur); struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur); - struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr)); + struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr)); struct ggml_tensor * k = ggml_sqr( ctx, ggml_relu( ctx, - ggml_mul_mat(ctx, layer->channel_mix_key, xk) + llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk) ) ); - return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k)); + return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k)); } struct llm_build_context { @@ -15109,7 +15111,7 @@ struct llm_build_context { 1 ); - cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states)); + cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states)); ggml_build_forward_expand(gf, cur); ggml_build_forward_expand( gf, @@ -15132,7 +15134,7 @@ struct llm_build_context { ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0), 1 ); - cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev)); + cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev)); ggml_build_forward_expand(gf, cur); struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); @@ -15166,7 +15168,7 @@ struct llm_build_context { cur = ggml_get_rows(ctx0, cur, inp_out_ids); cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur);