From de7d1823ed7e6c9054e10368ebe34e0c666af7b2 Mon Sep 17 00:00:00 2001 From: niansa Date: Wed, 28 Jun 2023 12:48:41 +0200 Subject: [PATCH] Implemented ggml_vk_soft_max --- ggml-vulkan.cpp | 144 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 127 insertions(+), 17 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 3c7beedde..1cc54d06f 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -212,15 +212,28 @@ std::vector getVecBlockQ4_0QS(T *x, unsigned nb, unsigned qk) { }; -static const std::string program_source_head = R"( -#version 450 +static const std::string program_source_head = R"(#version 450 + #extension GL_EXT_shader_explicit_arithmetic_types_float16: enable #extension GL_EXT_shader_explicit_arithmetic_types_int8: enable +#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable + #define QK4_0 32 #define QR4_0 2 #define QK4_1 32 + #define GELU_COEF_A 0.044715; #define SQRT_2_OVER_PI 0.79788456080286535587989211986876; + +#ifndef QK_K +#define QK_K 256 +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif )"; @@ -366,16 +379,6 @@ void ggml_vk_abmath(kp::Sequence& seq, seq.record(mgr.algorithm({inA, inB, out}, spirv, {size}, {}, {pushConsts})); } -template -void ggml_vk_add(Args&&... args) { - return ggml_vk_abmath<'+', with_row>(std::forward(args)...); -} - -template -void ggml_vk_mul(Args&&... args) { - return ggml_vk_abmath<'*', with_row>(std::forward(args)...); -} - static const std::string program_scale = MULTILINE_QUOTE( @@ -456,8 +459,8 @@ void ggml_vk_silu(Args&&... args) { static const std::string program_relu = MULTILINE_QUOTE( layout(push_constant) uniform PushConstants { - uint inAOff; uint inOff; + uint outOff; } pcs; layout(local_size_x = 1) in; @@ -482,8 +485,8 @@ void ggml_vk_relu(Args&&... args) { static const std::string program_gelu = MULTILINE_QUOTE( layout(push_constant) uniform PushConstants { - uint inAOff; uint inOff; + uint outOff; } pcs; layout(local_size_x = 1) in; @@ -506,6 +509,109 @@ void ggml_vk_gelu(Args&&... args) { } +static const std::string program_soft_max = + MULTILINE_QUOTE( +layout(push_constant) uniform PushConstants { + uint64_t ne00; + uint64_t ne01; + uint64_t ne02; + uint inOff; + uint outOff; +} pcs; + +layout(local_size_x = nth) in; +layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; }; +layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; + +shared float buf[nth]; + +void main() { + const uint64_t i03 = uint64_t(gl_GlobalInvocationID.z); + const uint64_t i02 = uint64_t(gl_GlobalInvocationID.y); + const uint64_t i01 = uint64_t(gl_GlobalInvocationID.x); + + const uint extra_off = uint(i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00); + const uint in_off = pcs.inOff + extra_off; + const uint out_off = pcs.outOff + extra_off; + + // parallel max + buf[gl_LocalInvocationID.x] = uintBitsToFloat(0xFF800000); + for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) { + buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], in_[in_off + i00]); + } + + // reduce + barrier(); + memoryBarrierShared(); + for (uint i = nth/2; i > 0; i /= 2) { + if (gl_LocalInvocationID.x < i) { + buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], buf[gl_LocalInvocationID.x + i]); + } + barrier(); + memoryBarrierShared(); + } + + // broadcast (no effect?) + if (gl_LocalInvocationID.x == 0) { + buf[0] = buf[0]; // ??? + } + + barrier(); + memoryBarrierShared(); + + const float max_ = buf[0]; + + // parallel sum + buf[gl_LocalInvocationID.x] = 0.0; + for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) { + buf[gl_LocalInvocationID.x] += exp(in_[in_off + i00] - max_); + } + + // reduce + barrier(); + memoryBarrierShared(); + for (uint i = nth/2; i > 0; i /= 2) { + if (gl_LocalInvocationID.x < i) { + buf[gl_LocalInvocationID.x] += buf[gl_LocalInvocationID.x + i]; + } + barrier(); + memoryBarrierShared(); + } + + // broadcast (no effect?) + if (gl_LocalInvocationID.x == 0) { + buf[0] = buf[0]; // ??? + } + + barrier(); + memoryBarrierShared(); + + const float sum = buf[0]; + + for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) { + out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum; + } +} +); + +void ggml_vk_soft_max(kp::Sequence& seq, + const std::shared_ptr& in, uint32_t inOff, + const std::shared_ptr& out, uint32_t outOff, + int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03) { + const static unsigned nth = 32; + const static auto spirv = compileSource(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_soft_max, __func__); + + struct PushConstants { + int64_t ne00, ne01, ne02; + uint32_t inOff, outOff; + } pushConsts { + ne00, ne01, ne02, inOff, outOff + }; + + seq.record(mgr.algorithm({in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts})); +} + + void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) { printf("%s: evaluating graph\n", __func__); @@ -585,15 +691,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_ADD: { - ggml_vk_add(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); + ggml_vk_abmath<'+'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); } break; case GGML_OP_MUL: { if (ggml_nelements(src1) == ne10) { // src1 is a row - ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00); + ggml_vk_abmath<'*', true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00); } else { - ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); + ggml_vk_abmath<'*'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); } } break; case GGML_OP_SCALE: @@ -613,6 +719,10 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { ggml_vk_gelu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst)); } break; + case GGML_OP_SOFT_MAX: + { + ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03); + } break; default: fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); //GGML_ASSERT(false);