diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5c81d51bc..1eb205524 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -228,6 +228,9 @@ struct vk_device_struct { vk_pipeline pipeline_simpler_mul_mat_q6_k; vk_pipeline pipeline_simpler_mul_mat_q8_0; + vk_pipeline pipeline_simpler_soft_max_f16; + vk_pipeline pipeline_simpler_soft_max_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; @@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants { uint32_t nrows_x; }; +struct vk_op_simpler_soft_max_push_constants { + int32_t ne00; + int32_t ne01; + int32_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + int32_t mask; +}; + struct vk_op_argsort_push_constants { uint32_t ncols; uint32_t ncols_pad; @@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + if (ctx->device->simpler_shaders) { + if (src1 && src1->type == GGML_TYPE_F16) { + return ctx->device->pipeline_simpler_soft_max_f16; + } + return ctx->device->pipeline_simpler_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); switch (op) { + case GGML_OP_SOFT_MAX: + if (ctx->device->simpler_shaders) { + elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] }; + break; + } + // fall-through case GGML_OP_NORM: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: { const uint32_t nr = ggml_nrows(src0); @@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { - ncols, - src1 != nullptr ? nrows_y : (uint32_t)0, - scale, max_bias, - m0, m1, - n_head_log2, - nrows_x, - }, dryrun); + if (ctx->device->simpler_shaders) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + (int32_t) src0->ne[0], + (int32_t) src0->ne[1], + (int32_t) src0->ne[2], + scale, max_bias, + m0, m1, + n_head_log2, + src1 == nullptr ? 0 : 1, + }, dryrun); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); + } } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp new file mode 100644 index 000000000..d2707c504 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp @@ -0,0 +1,69 @@ +// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4) + +#version 450 + +#include "simpler_common.comp" + +layout(local_size_x = 32) in; + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; + +layout(push_constant) uniform PushConstants { + int ne00; + int ne01; + int ne02; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + int mask; +} pcs; + +void main() { + if (gl_SubgroupInvocationID > 31) + return; + + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; + + const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; + const uint psrc0 = extra_off; + const uint pmask = i01*pcs.ne00; + const uint pdst = extra_off; + + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + + // parallel max + float localMax = uintBitsToFloat(0xFF800000); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); + } + float max_ = subgroupMax(localMax); + + // parallel sum + float localSum = 0.0f; + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); + localSum += exp_psrc0; + out_[pdst + i00] = exp_psrc0; + } + + const float sum = subgroupAdd(localSum); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + out_[pdst + i00] /= sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ec357977e..d616795e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -504,6 +504,9 @@ void process_shaders() { string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {}); string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {}); + string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}}); + string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); }