From 385b611d4590c3b761e97d6fd99f710c5b5a7c85 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Wed, 25 Dec 2024 15:29:19 +0800 Subject: [PATCH] RWKV: Some graph simplification Signed-off-by: Molly Sophia --- ggml/include/ggml.h | 9 +++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 4 ++-- ggml/src/ggml-cuda/wkv6.cu | 4 ++-- ggml/src/ggml-sycl/wkv6.cpp | 4 ++-- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- ggml/src/ggml.c | 18 +++++++++--------- src/llama.cpp | 27 ++++++++++----------------- 7 files changed, 36 insertions(+), 34 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c714fc8c8..748e18d17 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -513,6 +513,7 @@ extern "C" { GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, GGML_OP_RWKV_WKV6, + GGML_OP_GATED_LINEAR_ATTENTION, GGML_OP_UNARY, @@ -1873,6 +1874,14 @@ extern "C" { struct ggml_tensor * td, struct ggml_tensor * state); + GGML_API struct ggml_tensor * ggml_gated_linear_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b7fefb9dd..30a6d333b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos( static void ggml_compute_forward_rwkv_wkv6_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const int64_t T = dst->src[1]->ne[3]; + const int64_t T = dst->src[1]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t HEADS = dst->src[1]->ne[2]; + const int64_t HEADS = dst->src[1]->ne[1]; const int64_t n_seqs = dst->src[5]->ne[1]; const int64_t head_size = C / HEADS; diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu index 42578341a..bbdafbee5 100644 --- a/ggml/src/ggml-cuda/wkv6.cu +++ b/ggml/src/ggml-cuda/wkv6.cu @@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const float * s_d = (const float *)dst->src[5]->data; const int64_t B = dst->src[5]->ne[1]; - const int64_t T = dst->src[0]->ne[3]; + const int64_t T = dst->src[0]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[2]; + const int64_t H = dst->src[0]->ne[1]; float * dst_d = (float *)dst->data; diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp index 75ddfb86a..1920456f2 100644 --- a/ggml/src/ggml-sycl/wkv6.cpp +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -107,9 +107,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s float* dst_d = (float*)dst->data; const int64_t B = dst->src[5]->ne[1]; - const int64_t T = dst->src[0]->ne[3]; + const int64_t T = dst->src[0]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[2]; + const int64_t H = dst->src[0]->ne[1]; GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); GGML_ASSERT(C % H == 0); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 020e61280..b7187911a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5624,9 +5624,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc } static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { - const size_t seq_length = dst->src[0]->ne[3]; + const size_t seq_length = dst->src[0]->ne[2]; const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[2]; + const size_t n_heads = dst->src[0]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1]; ggml_vk_op_f32_rwkv6( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bbe5f482..739c94166 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GET_REL_POS", "ADD_REL_POS", "RWKV_WKV6", + "GATED_LINEAR_ATTN", "UNARY", @@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "get_rel_pos(x)", "add_rel_pos(x)", "rwkv_wkv6(k, v, r, tf, td, s)", + "gated_linear_attn(k, v, q, decay, s)", "unary(x)", @@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4640,15 +4642,13 @@ struct ggml_tensor * ggml_rwkv_wkv6( GGML_ASSERT(ggml_is_contiguous(state)); const int64_t S = k->ne[0]; - const int64_t H = k->ne[2]; - const int64_t n_tokens = k->ne[3]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; const int64_t n_seqs = state->ne[1]; { - GGML_ASSERT(k->ne[1] == 1); - GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); - GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); - // TODO: RWKV v4 and v5 - GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens); + GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens); GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); } diff --git a/src/llama.cpp b/src/llama.cpp index 8d617bde9..35230d1ce 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -142,11 +142,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w const int64_t H = 123; const int64_t n_tokens = 123; const int64_t n_seqs = 123; - ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); - ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); ggml_tensor * tf = w; - ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); } break; @@ -3296,9 +3296,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix( struct ggml_tensor * cur, struct ggml_tensor * x_prev, struct ggml_tensor ** wkv_state, - size_t head_count_kv, - bool key_decay, - bool skip_groupnorm) { + size_t head_count_kv) { size_t n_embd = cur->ne[0]; size_t n_seq_tokens = cur->ne[1]; size_t n_seqs = cur->ne[2]; @@ -3429,19 +3427,14 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix( ) ); - w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd)); + w = ggml_add(ctx, w, layer->time_mix_decay); w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w))); - w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens); - - r = ggml_reshape_4d(ctx, r, 1, head_size, head_count, n_tokens); - k = ggml_reshape_4d(ctx, k, 1, head_size, head_count, n_tokens); - v = ggml_reshape_4d(ctx, v, 1, head_size, head_count, n_tokens); + w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens); if (is_qrwkv) { // k = k * (1 - w) k = ggml_sub(ctx, k, ggml_mul(ctx, k, w)); } - k = ggml_transpose(ctx, k); struct ggml_tensor * wkv_output; if (!layer->time_mix_first) { @@ -3452,7 +3445,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix( cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); - if (!skip_groupnorm) { + if (!is_qrwkv) { // group norm with head_count groups cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens); cur = ggml_norm(ctx, cur, 64e-5f); @@ -9880,7 +9873,7 @@ struct llm_build_context { 1 ); - cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1], false, false)); + cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1])); ggml_build_forward_expand(gf, cur); ggml_build_forward_expand( gf, @@ -10003,7 +9996,7 @@ struct llm_build_context { ) ); - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv(), true, hparams.wkv_skip_groupnorm)); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv())); cb(ffn_inp, "ffn_inp", il); // feed-forward network