From c7b8ab73de3d4e968fb397c6e1923d6b254c8a0b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 14 Nov 2024 20:18:57 -0600 Subject: [PATCH] vulkan: Optimize soft_max Large soft_max could already saturate memory, but small/medium sizes were pretty slow. The bulk of the gains for them comes from using a smaller workgroup size, and making the workgroup size match the subgroup size also makes the barriers much cheaper. Cache some values in locals to avoid refetching/recomputing. And stamp out a few "template instantiations" so smaller cases will fully unroll. Add a missing early return for OOB rows. This happens when there are more than 512 rows and the dispatch is 512 x H. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +- .../ggml-vulkan/vulkan-shaders/soft_max.comp | 123 ++++++++++++++---- tests/test-backend-ops.cpp | 8 ++ 3 files changed, 111 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c02c35665..7b13cc10b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -388,6 +388,7 @@ struct vk_op_soft_max_push_constants { float m0; float m1; uint32_t n_head_log2; + uint32_t nrows_x; }; struct vk_op_argsort_push_constants { @@ -1496,8 +1497,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4581,6 +4582,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, scale, max_bias, m0, m1, n_head_log2, + nrows_x, }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 0bd51ecab..a83b9a405 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -1,6 +1,7 @@ #version 450 -#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_control_flow_attributes : enable layout (push_constant) uniform parameter { @@ -11,14 +12,13 @@ layout (push_constant) uniform parameter float m0; float m1; uint n_head_log2; + uint nrows_x; } p; #include "types.comp" -#extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; @@ -26,11 +26,18 @@ layout (binding = 2) buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE vals[BLOCK_SIZE]; -void main() { +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint rowy = rowx % p.KY; + if (rowx >= p.nrows_x) { + return; + } + float slope = 1.0f; // ALiBi @@ -46,19 +53,37 @@ void main() { // Find max FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { const uint col = col0 + tid; - if (col >= p.KX) { - break; + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; } - max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f))); - } - vals[tid] = max_val; + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy * p.KX + col]; + } + FLOAT_TYPE v = a * p.scale + slope * b; + + max_val = max(max_val, v); + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { vals[tid] = max(vals[tid], vals[tid + s]); } @@ -68,39 +93,89 @@ void main() { max_val = vals[0]; barrier(); - // Sum up values - vals[tid] = FLOAT_TYPE(0.0f); + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { const uint col = col0 + tid; if (col >= p.KX) { break; } + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. const uint i = rowx * p.KX + col; - const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); - vals[tid] += val; - data_d[i] = D_TYPE(val); + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } } + // reduce across the workgroup + vals[tid] = sum; barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { vals[tid] += vals[tid + s]; } barrier(); } + sum = vals[0]; - const D_TYPE divisor = D_TYPE(vals[0]); + FLOAT_TYPE rcpdivisor = 1.0/sum; - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { const uint col = col0 + tid; if (col >= p.KX) { - break; + continue; } - data_d[rowx*p.KX + col] /= divisor; + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + switch (num_blocks) { + case 1: + soft_max(1); + break; + case 2: + soft_max(2); + break; + case 3: + soft_max(3); + break; + case 4: + soft_max(4); + break; + case 5: + case 6: + case 7: + case 8: + soft_max(8); + break; + case 16: + soft_max(16); + break; + default: + soft_max(num_blocks); + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6618d03d1..24a5c03f6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3790,6 +3790,14 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f)); + for (int bs : {1, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) {