diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2d93f31fa..aa38164ec 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -509,7 +509,7 @@ extern "C" { GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, - GGML_OP_RWKV_WKV, + GGML_OP_RWKV_WKV6, GGML_OP_UNARY, @@ -1879,7 +1879,7 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); - GGML_API struct ggml_tensor * ggml_rwkv_wkv( + GGML_API struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, struct ggml_tensor * k, struct ggml_tensor * v, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b57f1b3b7..9ae59265e 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2313,8 +2313,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; - case GGML_OP_RWKV_WKV: - ggml_cuda_op_rwkv_wkv(ctx, dst); + case GGML_OP_RWKV_WKV6: + ggml_cuda_op_rwkv_wkv6(ctx, dst); break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_cuda_cross_entropy_loss_back(ctx, dst); @@ -3147,7 +3147,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: return true; case GGML_OP_FLASH_ATTN_EXT: { #ifndef FLASH_ATTN_AVAILABLE diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/rwkv-wkv.cu index 098e92d35..761a81d75 100644 --- a/ggml/src/ggml-cuda/rwkv-wkv.cu +++ b/ggml/src/ggml-cuda/rwkv-wkv.cu @@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const } } -void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * k_d = (const float *)dst->src[0]->data; const float * v_d = (const float *)dst->src[1]->data; const float * r_d = (const float *)dst->src[2]->data; diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/rwkv-wkv.cuh index 13795247f..a7124ee51 100644 --- a/ggml/src/ggml-cuda/rwkv-wkv.cuh +++ b/ggml/src/ggml-cuda/rwkv-wkv.cuh @@ -2,4 +2,4 @@ #define CUDA_WKV_BLOCK_SIZE 64 -void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 84f2c766b..fdc15b0dd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3169,7 +3169,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", - "rwkv_wkv(k, v, r, tf, td, s)", + "rwkv_wkv6(k, v, r, tf, td, s)", "unary(x)", @@ -7361,9 +7361,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -// ggml_rwkv_wkv +// ggml_rwkv_wkv6 -struct ggml_tensor * ggml_rwkv_wkv( +struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, struct ggml_tensor * k, struct ggml_tensor * v, @@ -7395,7 +7395,7 @@ struct ggml_tensor * ggml_rwkv_wkv( const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_RWKV_WKV; + result->op = GGML_OP_RWKV_WKV6; result->src[0] = k; result->src[1] = v; result->src[2] = r; @@ -16604,15 +16604,16 @@ static void ggml_compute_forward_add_rel_pos( } } -// ggml_compute_forward_rwkv_wkv +// ggml_compute_forward_rwkv_wkv6 -static void ggml_compute_forward_rwkv_wkv_f32( +static void ggml_compute_forward_rwkv_wkv6_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const size_t T = dst->src[1]->ne[3]; const size_t C = dst->ne[0]; const size_t H = dst->src[1]->ne[2]; const size_t n_seqs = dst->src[5]->ne[1]; + const size_t head_size = C / H; float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; @@ -16629,10 +16630,10 @@ static void ggml_compute_forward_rwkv_wkv_f32( float * time_faaaa = (float *) dst->src[3]->data; float * time_decay = (float *) dst->src[4]->data; - size_t t_stride = H * (C / H); + size_t t_stride = H * head_size; size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); + size_t h_stride_2d = head_size * head_size; // basically fused operations: // dst = r @ (time_faaaa * (k @ v) + state), @@ -16640,7 +16641,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( // recursive through each token for (size_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; - size_t state_offset = (C / H) * C * (t / (T / n_seqs)); + size_t state_offset = head_size * C * (t / (T / n_seqs)); float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; @@ -16649,7 +16650,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; - for (size_t i = 0; i < C / H; i++) { + for (size_t i = 0; i < head_size; i++) { size_t t_h_i_offset = t_h_offset + i; size_t h_i_offset = h_offset + i; size_t h_2d_i_offset = h_2d_offset + i * h_stride; @@ -16660,7 +16661,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( // RWKV v6: different time_decay for each token. float time_decay_val = time_decay[t_h_i_offset]; - for (size_t j = 0; j < C / H; j ++) { + for (size_t j = 0; j < head_size; j ++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; @@ -16676,7 +16677,8 @@ static void ggml_compute_forward_rwkv_wkv_f32( } } -static void ggml_compute_forward_rwkv_wkv( + +static void ggml_compute_forward_rwkv_wkv6( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -16685,7 +16687,7 @@ static void ggml_compute_forward_rwkv_wkv( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_rwkv_wkv_f32(params, dst); + ggml_compute_forward_rwkv_wkv6_f32(params, dst); } break; default: { @@ -17437,9 +17439,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add_rel_pos(params, tensor); } break; - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: { - ggml_compute_forward_rwkv_wkv(params, tensor); + ggml_compute_forward_rwkv_wkv6(params, tensor); } break; case GGML_OP_MAP_UNARY: { @@ -18628,7 +18630,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GET_REL_POS: case GGML_OP_ADD_REL_POS: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -19278,7 +19280,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/src/llama.cpp b/src/llama.cpp index 3f534596e..265a81fbe 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7011,7 +7011,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, - {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7127,7 +7127,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); } break; - case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_WKV6: { // FIXME const int64_t S = 123; @@ -7140,7 +7140,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * tf = w; ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); @@ -10083,7 +10083,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix( v = ggml_transpose(ctx, v); r = ggml_transpose(ctx, r); - struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); + struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0); *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e3ad79f0..7a3635238 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1613,8 +1613,8 @@ struct test_ssm_scan : public test_case { } }; -// GGML_OP_RWKV_WKV -struct test_rwkv_wkv : public test_case { +// GGML_OP_RWKV_WKV6 +struct test_rwkv_wkv6 : public test_case { const ggml_type type; const int64_t head_count; @@ -1626,7 +1626,7 @@ struct test_rwkv_wkv : public test_case { return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs); } - test_rwkv_wkv(ggml_type type = GGML_TYPE_F32, + test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32, int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} @@ -1638,7 +1638,7 @@ struct test_rwkv_wkv : public test_case { ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector{ head_size, head_count }.data()); ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector{ 1, head_size, head_count, n_tokens }.data()); ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data()); - ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s); + ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s); return out; } }; @@ -3498,10 +3498,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4)); - test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4)); + test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); #if 1 for (ggml_type type_a : base_types) {