From 60bbd4ebf174986fb1b918310ab65074847d79ed Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Fri, 13 Dec 2024 17:43:08 +0800
Subject: [PATCH] Apply code format changes

Signed-off-by: Molly Sophia
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 97 ++++++-------
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |  2 +-
 ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp | 75 ++++++--------
 3 files changed, 60 insertions(+), 114 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index da11e88cd..4c0fb4d46 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -524,15 +524,13 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
-
 struct vk_op_rwkv_wkv6_push_constants {
-    uint32_t B; // Batch size (原n_seqs)
-    uint32_t T; // Sequence length
-    uint32_t C; // Total channels
-    uint32_t H; // Number of heads (原HEADS)
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
 };
 
-
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1952,19 +1950,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(
-        device,
-        device->pipeline_rwkv_wkv6_f32,
-        "rwkv_wkv6_f32",
-        rwkv_wkv6_f32_len,
-        rwkv_wkv6_f32_data,
-        "main",
-        7,
-        sizeof(vk_op_rwkv_wkv6_push_constants),
-        {1, 1, 1}, // work group
-        {device->subgroup_size},
-        1
-    );
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     for (auto &c : compiles) {
         c.wait();
@@ -5348,28 +5334,14 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
-
-template <typename PC>
-static void ggml_vk_op_f32_rwkv6(
-    ggml_backend_vk_context * ctx,
-    vk_context& subctx,
-    ggml_tensor * dst,
-    const PC&& pc,
-    bool dryrun = false) {
-
-    // Get source tensors
-    const ggml_tensor * k = dst->src[0]; // keys
-    const ggml_tensor * v = dst->src[1]; // values
-    const ggml_tensor * r = dst->src[2]; // reset gates
-    const ggml_tensor * tf = dst->src[3]; // time first
-    const ggml_tensor * td = dst->src[4]; // time decay
-    const ggml_tensor * state = dst->src[5]; // states
-
-    VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
-                 << tf << ", " << td << ", " << state << ", " << dst << ")");
-
-    // Verify input types
 
     GGML_ASSERT(!ggml_is_quantized(k->type));
     GGML_ASSERT(!ggml_is_quantized(v->type));
     GGML_ASSERT(!ggml_is_quantized(r->type));
@@ -5378,7 +5350,6 @@ static void ggml_vk_op_f32_rwkv6(
     GGML_ASSERT(!ggml_is_quantized(state->type));
     GGML_ASSERT(dst->buffer != nullptr);
 
-    // Get pipeline
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
     GGML_ASSERT(pipeline != nullptr);
 
@@ -5387,7 +5358,6 @@ static void ggml_vk_op_f32_rwkv6(
         return;
     }
 
-    // Get buffer contexts
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
     ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
@@ -5396,7 +5366,6 @@ static void ggml_vk_op_f32_rwkv6(
     ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
     ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
 
-    // Get device buffers
     vk_buffer d_D = dst_buf_ctx->dev_buffer;
     vk_buffer d_K = k_buf_ctx->dev_buffer;
     vk_buffer d_V = v_buf_ctx->dev_buffer;
@@ -5405,7 +5374,6 @@ static void ggml_vk_op_f32_rwkv6(
     vk_buffer d_TD = td_buf_ctx->dev_buffer;
     vk_buffer d_State = state_buf_ctx->dev_buffer;
 
-    // Calculate buffer offsets
     const uint64_t k_offset = vk_tensor_offset(k);
     const uint64_t v_offset = vk_tensor_offset(v);
     const uint64_t r_offset = vk_tensor_offset(r);
@@ -5414,7 +5382,6 @@ static void ggml_vk_op_f32_rwkv6(
     const uint64_t state_offset = vk_tensor_offset(state);
     const uint64_t dst_offset = vk_tensor_offset(dst);
 
-    // Calculate buffer sizes
    const uint64_t k_size = ggml_nbytes(k);
     const uint64_t v_size = ggml_nbytes(v);
     const uint64_t r_size = ggml_nbytes(r);
@@ -5423,14 +5390,12 @@ static void ggml_vk_op_f32_rwkv6(
     const uint64_t state_size = ggml_nbytes(state);
     const uint64_t dst_size = ggml_nbytes(dst);
 
-    // Set work elements based on tensor dimensions
     std::array<uint32_t, 3> elements = {
-        (uint32_t)(pc.B*pc.H), // B * H workgroups
-        1, // 每个workgroup 64个线程
+        (uint32_t)(pc.B * pc.H),
+        1,
         1
     };
 
-    // Synchronize buffers and dispatch compute pipeline
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
         vk_subbuffer{ d_K, k_offset, k_size },
@@ -5440,35 +5405,27 @@ static void ggml_vk_op_f32_rwkv6(
         vk_subbuffer{ d_TD, td_offset, td_size },
         vk_subbuffer{ d_State, state_offset, state_size },
         vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(PC), &pc, elements);
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
 }
 
-static void ggml_vk_rwkv_wkv6(
-    ggml_backend_vk_context * ctx,
-    vk_context& subctx,
-    ggml_tensor * dst,
-    bool dryrun = false) {
-
-    // Extract dimensions from tensors
-    const size_t T = dst->src[0]->ne[3]; // Sequence length
-    const size_t C = dst->ne[0]; // Channel dimension
-    const size_t HEADS = dst->src[0]->ne[2]; // Number of heads
-    const size_t n_seqs = dst->src[5]->ne[1]; // Batch size
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
 
-    // Call implementation with push constants
-    ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
+    ggml_vk_op_f32_rwkv6(
         ctx, subctx, dst,
         {
-            (uint32_t)n_seqs, // B
-            (uint32_t)T, // T
-            (uint32_t)C, // C
-            (uint32_t)HEADS, // H
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
         },
         dryrun
     );
 }
 
-
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
@@ -8344,10 +8301,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5]);
-    }
+    }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index eff60f3c3..7a0d7285d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -479,7 +479,7 @@ void process_shaders() {
 
     string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
-    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
+    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     for (auto &c : compiles) {
         c.wait();
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp
index 6465f2da9..8beb7ff6e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp
@@ -1,96 +1,85 @@
 #version 450
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 layout(push_constant) uniform Parameters {
-    uint B; // Batch size
-    uint T; // Sequence length
-    uint C; // Total number of channels
-    uint H; // Number of heads
+    uint B;
+    uint T;
+    uint C;
+    uint H;
 };
 
-layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; };
-layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; };
-layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; };
-layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; };
-layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; };
-layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; };
-layout(set = 0, binding = 6) buffer DstBuf { float dst[]; };
+layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
+layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
 
-shared float _k[64], _r[64], _tf[64], _td[64];
+shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];
 
 void main() {
-    const uint head_size = 64;
+    const uint head_size = BLOCK_SIZE;
     const uint batch_id = gl_WorkGroupID.x / H;
     const uint head_id = gl_WorkGroupID.x % H;
     const uint tid = gl_LocalInvocationID.x;
-    
+
     const uint state_size = C * head_size;
     const uint n_seq_tokens = T / B;
 
     if (tid >= head_size || batch_id >= B || head_id >= H) {
         return;
     }
-    
-    // Load state
-    float state[64]; // Use fixed size matching head_size
+
+    A_TYPE state[BLOCK_SIZE];
     for (uint i = 0; i < head_size; i++) {
-        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                             + i * head_size + tid];
     }
-    
-    _k[tid] = 0.0;
-    _r[tid] = 0.0;
-    _td[tid] = 0.0;
+
     barrier();
     _tf[tid] = tf[head_id * head_size + tid];
     barrier();
-    
-    // Main loop
     const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
     const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
-    
+
     for (uint t = start_t; t < end_t; t += C) {
         barrier();
         _k[tid] = k[t];
         _r[tid] = r[t];
         _td[tid] = td[t];
         barrier();
-        
-        const float v_val = v[t];
-        float y = 0.0;
-        
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
         for (uint j = 0; j < head_size; j += 4) {
-            // Load values in blocks of 4
             vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
             vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
             vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
             vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
             vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
-            
-            // Compute kv products
+
             vec4 kv = k_vec * v_val;
-            
-            // Accumulate results
+
             vec4 temp = tf_vec * kv + s_vec;
             y += dot(r_vec, temp);
-            
-            // Update state
+
             s_vec = s_vec * td_vec + kv;
             state[j] = s_vec.x;
             state[j+1] = s_vec.y;
             state[j+2] = s_vec.z;
             state[j+3] = s_vec.w;
         }
-        
+
         dst[t] = y;
     }
-    
-    // Write back state
+
     for (uint i = 0; i < head_size; i++) {
-        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
             + i * head_size + tid] = state[i];
     }
-}
\ No newline at end of file
+}