From bae9a58f5df8fe4e6fe7c8b71a4383c15a5f1cda Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Fri, 7 Feb 2025 14:57:03 +0100 Subject: [PATCH] vulkan: enable the use of simpler softmax shaders Even though the regular softmax shaders successfully pass test-backend-ops with Apple GPUs, running long inference tests has shown the models end derailing with softmax OPs being the root cause. With this commit, we use simpler softmax shaders borrowed from the Kompute backend (which are basically reimplementations of the Metal shaders) on certain GPUs know to have problem with the regular ones. Signed-off-by: Sergio Lopez --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 60 +++++++++++++--- .../vulkan-shaders/simpler_soft_max.comp | 69 +++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + 3 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5c81d51bc..1eb205524 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -228,6 +228,9 @@ struct vk_device_struct { vk_pipeline pipeline_simpler_mul_mat_q6_k; vk_pipeline pipeline_simpler_mul_mat_q8_0; + vk_pipeline pipeline_simpler_soft_max_f16; + vk_pipeline pipeline_simpler_soft_max_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; @@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants { uint32_t nrows_x; }; +struct vk_op_simpler_soft_max_push_constants { + int32_t ne00; + int32_t ne01; + int32_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + int32_t mask; +}; + struct vk_op_argsort_push_constants { uint32_t ncols; uint32_t ncols_pad; @@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + if (ctx->device->simpler_shaders) { + if (src1 && src1->type == GGML_TYPE_F16) { + return ctx->device->pipeline_simpler_soft_max_f16; + } + return ctx->device->pipeline_simpler_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); switch (op) { + case GGML_OP_SOFT_MAX: + if (ctx->device->simpler_shaders) { + elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] }; + break; + } + // fall-through case GGML_OP_NORM: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: { const uint32_t nr = ggml_nrows(src0); @@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { - ncols, - src1 != nullptr ? nrows_y : (uint32_t)0, - scale, max_bias, - m0, m1, - n_head_log2, - nrows_x, - }, dryrun); + if (ctx->device->simpler_shaders) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + (int32_t) src0->ne[0], + (int32_t) src0->ne[1], + (int32_t) src0->ne[2], + scale, max_bias, + m0, m1, + n_head_log2, + src1 == nullptr ? 0 : 1, + }, dryrun); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); + } } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp new file mode 100644 index 000000000..d2707c504 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp @@ -0,0 +1,69 @@ +// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4) + +#version 450 + +#include "simpler_common.comp" + +layout(local_size_x = 32) in; + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; + +layout(push_constant) uniform PushConstants { + int ne00; + int ne01; + int ne02; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + int mask; +} pcs; + +void main() { + if (gl_SubgroupInvocationID > 31) + return; + + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; + + const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; + const uint psrc0 = extra_off; + const uint pmask = i01*pcs.ne00; + const uint pdst = extra_off; + + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + + // parallel max + float localMax = uintBitsToFloat(0xFF800000); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); + } + float max_ = subgroupMax(localMax); + + // parallel sum + float localSum = 0.0f; + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); + localSum += exp_psrc0; + out_[pdst + i00] = exp_psrc0; + } + + const float sum = subgroupAdd(localSum); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + out_[pdst + i00] /= sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ec357977e..d616795e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -504,6 +504,9 @@ void process_shaders() { string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {}); string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {}); + string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}}); + string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); }