From 4651f5e2f29b32e24c69c511d0bacb14d29e6008 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Sat, 2 Nov 2024 01:45:27 +1100 Subject: [PATCH] rwkv_wkv6 vulkan shader --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 165 +++++++++++++++++- .../ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp | 96 ++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8ae58ee2..e103e67f7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -240,6 +240,7 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; @@ -523,6 +524,15 @@ struct vk_op_pool2d_push_constants { int32_t p0; int32_t p1; }; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; // Batch size (formerly n_seqs) + uint32_t T; // Sequence length + uint32_t C; // Total channels + uint32_t H; // Number of heads (formerly HEADS) +}; + + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1942,6 +1952,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline( + device, + device->pipeline_rwkv_wkv6_f32, + "rwkv_wkv6_f32", + rwkv_wkv6_f32_len, + rwkv_wkv6_f32_data, + "main", + 7, + sizeof(vk_op_rwkv_wkv6_push_constants), + {64, 1, 1}, // work group + {device->subgroup_size}, + 1 + ); + for (auto &c : compiles) { c.wait(); } @@ -4917,6 +4941,11 
@@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_pool2d_f32; } return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -5319,6 +5348,127 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } + + +template <typename PC> +static void ggml_vk_op_f32_rwkv6( + ggml_backend_vk_context * ctx, + vk_context& subctx, + ggml_tensor * dst, + const PC&& pc, + bool dryrun = false) { + + // Get source tensors + const ggml_tensor * k = dst->src[0]; // keys + const ggml_tensor * v = dst->src[1]; // values + const ggml_tensor * r = dst->src[2]; // receptance + const ggml_tensor * tf = dst->src[3]; // time first + const ggml_tensor * td = dst->src[4]; // time decay + const ggml_tensor * state = dst->src[5]; // states + + VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", " + << tf << ", " << td << ", " << state << ", " << dst << ")"); + + // Verify input types + GGML_ASSERT(!ggml_is_quantized(k->type)); + GGML_ASSERT(!ggml_is_quantized(v->type)); + GGML_ASSERT(!ggml_is_quantized(r->type)); + GGML_ASSERT(!ggml_is_quantized(tf->type)); + GGML_ASSERT(!ggml_is_quantized(td->type)); + GGML_ASSERT(!ggml_is_quantized(state->type)); + GGML_ASSERT(dst->buffer != nullptr); + + // Get pipeline + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + // Get buffer contexts + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context 
*)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; + ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; + ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; + ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + + // Get device buffers + vk_buffer d_D = dst_buf_ctx->dev_buffer; + vk_buffer d_K = k_buf_ctx->dev_buffer; + vk_buffer d_V = v_buf_ctx->dev_buffer; + vk_buffer d_R = r_buf_ctx->dev_buffer; + vk_buffer d_TF = tf_buf_ctx->dev_buffer; + vk_buffer d_TD = td_buf_ctx->dev_buffer; + vk_buffer d_State = state_buf_ctx->dev_buffer; + + // Calculate buffer offsets + const uint64_t k_offset = vk_tensor_offset(k); + const uint64_t v_offset = vk_tensor_offset(v); + const uint64_t r_offset = vk_tensor_offset(r); + const uint64_t tf_offset = vk_tensor_offset(tf); + const uint64_t td_offset = vk_tensor_offset(td); + const uint64_t state_offset = vk_tensor_offset(state); + const uint64_t dst_offset = vk_tensor_offset(dst); + + // Calculate buffer sizes + const uint64_t k_size = ggml_nbytes(k); + const uint64_t v_size = ggml_nbytes(v); + const uint64_t r_size = ggml_nbytes(r); + const uint64_t tf_size = ggml_nbytes(tf); + const uint64_t td_size = ggml_nbytes(td); + const uint64_t state_size = ggml_nbytes(state); + const uint64_t dst_size = ggml_nbytes(dst); + + // Set work elements based on tensor dimensions + std::array<uint32_t, 3> elements = { + (uint32_t)(pc.B*pc.H), // B * H workgroups + 1, // 64 threads per workgroup + 1 + }; + + // Synchronize buffers and dispatch compute pipeline + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_K, k_offset, k_size }, + vk_subbuffer{ d_V, v_offset, v_size }, + vk_subbuffer{ d_R, 
r_offset, r_size }, + vk_subbuffer{ d_TF, tf_offset, tf_size }, + vk_subbuffer{ d_TD, td_offset, td_size }, + vk_subbuffer{ d_State, state_offset, state_size }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(PC), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6( + ggml_backend_vk_context * ctx, + vk_context& subctx, + ggml_tensor * dst, + bool dryrun = false) { + + // Extract dimensions from tensors + const size_t T = dst->src[0]->ne[3]; // Sequence length + const size_t C = dst->ne[0]; // Channel dimension + const size_t HEADS = dst->src[0]->ne[2]; // Number of heads + const size_t n_seqs = dst->src[5]->ne[1]; // Batch size + + // Call implementation with push constants + ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>( + ctx, subctx, dst, + { + (uint32_t)n_seqs, // B + (uint32_t)T, // T + (uint32_t)C, // C + (uint32_t)HEADS, // H + }, + dryrun + ); +} + + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -6464,6 +6614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: break; @@ -6663,6 +6814,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_FLASH_ATTN_EXT: ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; default: return false; @@ -6743,6 +6899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: buf = tensor->buffer; @@ -7610,6 +7767,7 @@ static bool 
ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: return true; default: @@ -8186,7 +8344,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); - } else { + } + // else if (tensor->op == GGML_OP_RWKV_WKV6) { + // tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + // tensor->src[4], tensor->src[5]); + // } + else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp new file mode 100644 index 000000000..6465f2da9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp @@ -0,0 +1,96 @@ +#version 450 + + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; // Batch size + uint T; // Sequence length + uint C; // Total number of channels + uint H; // Number of heads +}; + +layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; }; +layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; }; +layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; }; +layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; }; +layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; }; +layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; }; +layout(set = 0, binding = 6) buffer DstBuf { float dst[]; }; + +shared float _k[64], _r[64], _tf[64], _td[64]; + +void main() { + const uint head_size = 64; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = 
gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (tid >= head_size || batch_id >= B || head_id >= H) { + return; + } + + // Load state + float state[64]; // Use fixed size matching head_size + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + _k[tid] = 0.0; + _r[tid] = 0.0; + _td[tid] = 0.0; + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + + // Main loop + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const float v_val = v[t]; + float y = 0.0; + + for (uint j = 0; j < head_size; j += 4) { + // Load values in blocks of 4 + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute kv products + vec4 kv = k_vec * v_val; + + // Accumulate results + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + // Update state + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + // Write back state + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} \ No newline at end of file