diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 39aff9e39..8bf39cb5d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1899,12 +1899,12 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_rwkv_wkv(
             struct ggml_context * ctx,
-            struct ggml_tensor * k,
-            struct ggml_tensor * v,
-            struct ggml_tensor * r,
-            struct ggml_tensor * tf,
-            struct ggml_tensor * td,
-            struct ggml_tensor * state);
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
 
     // custom operators
 
diff --git a/src/llama.cpp b/src/llama.cpp
index ef6c4632b..6c374277d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9383,24 +9383,25 @@ static struct ggml_tensor * llm_build_mamba(
     return cur;
 }
 
-static struct ggml_tensor * llm_build_time_mix_rwkv6(
-    struct llama_context & lctx,
-    struct ggml_context * ctx,
-    const struct llama_layer * layer,
-    struct ggml_tensor * cur,
-    struct ggml_tensor * x_prev,
-    struct ggml_tensor ** wkv_state) {
-    size_t n_embed = cur->ne[0];
+static struct ggml_tensor * llm_build_rwkv6_time_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor ** wkv_state) {
+    size_t n_embed      = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
-    size_t n_seqs = cur->ne[2];
-    size_t head_size = layer->time_mix_first->ne[0];
+    size_t n_seqs       = cur->ne[2];
+
+    size_t head_size  = layer->time_mix_first->ne[0];
     size_t head_count = layer->time_mix_first->ne[1];
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    sx = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
+    sx  = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
 
@@ -9498,6 +9499,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
             ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
         )
     );
+
     w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
@@ -9505,6 +9507,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     k = ggml_transpose(ctx, k);
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
+
     struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
     cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
@@ -9512,6 +9515,7 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     // group norm with head_count groups
     cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
     cur = ggml_norm(ctx, cur, 64e-5f);
+
     // Convert back to regular vectors.
     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
@@ -9522,12 +9526,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
 }
 
-static struct ggml_tensor * llm_build_channel_mix_rwkv6(
-    struct llama_context & lctx,
-    struct ggml_context * ctx,
-    const struct llama_layer * layer,
-    struct ggml_tensor * cur,
-    struct ggml_tensor * x_prev) {
+static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev) {
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
     struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
     struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
@@ -9540,6 +9544,7 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6(
             llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
         )
     );
+
     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
 }
 
@@ -15111,7 +15116,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -15134,7 +15139,7 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
            );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
 
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
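Note on the wkv_output views in the time-mix hunk above: ggml_rwkv_wkv returns one flat tensor in which the per-token outputs (n_embed * n_tokens floats) are followed directly by the updated recurrent state (n_embed * head_size * n_seqs floats), which is why the second ggml_view_1d starts at byte offset n_embed * n_tokens * sizeof(float). The sketch below is not part of the patch and makes no ggml calls; it only illustrates that buffer layout in plain C++, and the names wkv_split and split_wkv_output are invented for illustration.

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative only: mirrors the two ggml_view_1d calls in llm_build_rwkv6_time_mix.
struct wkv_split {
    float * out;   // per-token outputs: n_embed * n_tokens floats
    float * state; // updated wkv state: n_embed * head_size * n_seqs floats
};

// Assumes wkv_output is one contiguous float buffer with the token outputs first and
// the recurrent state immediately after, as the byte offset in the patch implies.
static wkv_split split_wkv_output(float * wkv_output, size_t n_embed, size_t n_tokens) {
    wkv_split s;
    s.out   = wkv_output;                      // first view:  starts at offset 0
    s.state = wkv_output + n_embed * n_tokens; // second view: starts right after the outputs
    return s;
}

int main() {
    const size_t n_embed = 8, n_tokens = 4, head_size = 2, n_seqs = 1;
    std::vector<float> buf(n_embed * n_tokens + n_embed * head_size * n_seqs, 0.0f);
    wkv_split s = split_wkv_output(buf.data(), n_embed, n_tokens);
    std::printf("state starts %zu floats after out\n", (size_t) (s.state - s.out));
    return 0;
}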