From f0e1429d7fd56483ee8352c6bca40344a653f01a Mon Sep 17 00:00:00 2001 From: niansa Date: Fri, 30 Jun 2023 16:01:08 +0200 Subject: [PATCH] Implemented RMS_NORM --- ggml-vulkan.cpp | 99 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 90 insertions(+), 9 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 29c67e776..35d31157b 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -657,6 +657,23 @@ void ggml_vk_soft_max(kp::Sequence& seq, } +void ggml_vk_norm(kp::Sequence& seq, std::vector<uint32_t> spirv, unsigned nth, + const std::shared_ptr<kp::Tensor>& in, uint32_t inOff, + const std::shared_ptr<kp::Tensor>& out, uint32_t outOff, + int64_t ne00, int64_t ne01, + int64_t nrows) { + struct PushConstants { + uint64_t ne00, nb01; + float eps; + uint32_t inOff, outOff; + } pushConsts { + (uint64_t)ne00, (uint64_t)ne01, 1e-5f, inOff, outOff + }; + + seq.record(mgr.algorithm({in, out}, spirv, {(uint32_t)nrows, nth}, {}, {pushConsts})); +} + + static const std::string program_norm = MULTILINE_QUOTE( layout(push_constant) uniform PushConstants { @@ -681,6 +698,7 @@ void main() { for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { sum[gl_GlobalInvocationID.y] += in_[x+i00]; } + // reduce barrier(); memoryBarrierShared(); @@ -691,6 +709,7 @@ void main() { barrier(); memoryBarrierShared(); } + // broadcast if (gl_GlobalInvocationID.y == 0) { sum[0] /= float(pcs.ne00); @@ -711,6 +730,7 @@ void main() { for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { sum[gl_GlobalInvocationID.y] += out_[y+i00] * out_[y+i00]; } + // reduce barrier(); memoryBarrierShared(); @@ -721,6 +741,7 @@ void main() { barrier(); memoryBarrierShared(); } + // broadcast if (gl_GlobalInvocationID.y == 0) { sum[0] /= float(pcs.ne00); @@ -744,17 +765,74 @@ void ggml_vk_norm(kp::Sequence& seq, const static unsigned nth = 256; const static auto spirv = glsl_compile_source(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_norm, __func__); - struct 
PushConstants { - uint64_t ne00, nb01; - float eps; - uint32_t inOff, outOff; - } pushConsts { - (uint64_t)ne00, (uint64_t)ne01, 1e-5f, inOff, outOff - }; - - seq.record(mgr.algorithm({in, out}, spirv, {(uint32_t)nrows, nth}, {}, {pushConsts})); + ggml_vk_norm(seq, spirv, nth, in, inOff, out, outOff, ne00, ne01, nrows); } + +static const std::string program_rms_norm = + MULTILINE_QUOTE( +layout(push_constant) uniform PushConstants { + uint64_t ne00; + uint64_t nb01; + float eps; + uint inOff; + uint outOff; +} pcs; + +layout(local_size_x = 1) in; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; +layout(binding = 1) buffer restrict tensorOut { float out_[]; }; + +shared float sum[nth]; + +void main() { + const uint x = gl_GlobalInvocationID.x; // Based from in_ + + // parallel sum + sum[gl_GlobalInvocationID.y] = 0.0; + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + sum[gl_GlobalInvocationID.y] += in_[x+i00] * in_[x+i00]; + } + + // reduce + barrier(); + memoryBarrierShared(); + for (uint i = nth/2; i > 0; i /= 2) { + if (gl_GlobalInvocationID.y < i) { + sum[gl_GlobalInvocationID.y] += sum[gl_GlobalInvocationID.y + i]; + } + barrier(); + memoryBarrierShared(); + } + + // broadcast + if (gl_GlobalInvocationID.y == 0) { + sum[0] /= float(pcs.ne00); + } + barrier(); + memoryBarrierShared(); + + const float scale = 1.0f/sqrt(sum[0] + pcs.eps); + + const uint y = gl_GlobalInvocationID.x; // Based from out_ + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + out_[y+i00] = in_[x+i00] * scale; + } +} +); + +void ggml_vk_rms_norm(kp::Sequence& seq, + const std::shared_ptr<kp::Tensor>& in, uint32_t inOff, + const std::shared_ptr<kp::Tensor>& out, uint32_t outOff, + int64_t ne00, int64_t ne01, + int64_t nrows) { + const static unsigned nth = 256; + const static auto spirv = glsl_compile_source(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_rms_norm, __func__); + + ggml_vk_norm(seq, spirv, nth, in, inOff, 
out, outOff, ne00, ne01, nrows); +} + + static const std::string program_diag_mask_inf = MULTILINE_QUOTE( layout(push_constant) uniform PushConstants { @@ -1201,6 +1279,9 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph case GGML_OP_NORM: { ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); } break; + case GGML_OP_RMS_NORM: { + ggml_vk_rms_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); + } break; default: fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); //GGML_ASSERT(false);