diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 4d9c458df..29c67e776 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -482,7 +482,7 @@ layout(push_constant) uniform PushConstants { } pcs; layout(local_size_x = 1) in; -layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; }; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; void main() { @@ -509,7 +509,7 @@ layout(push_constant) uniform PushConstants { } pcs; layout(local_size_x = 1) in; -layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; }; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; void main() { @@ -535,7 +535,7 @@ layout(push_constant) uniform PushConstants { } pcs; layout(local_size_x = 1) in; -layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; }; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; void main() { @@ -565,7 +565,7 @@ layout(push_constant) uniform PushConstants { } pcs; layout(local_size_x = nth) in; -layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; }; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; shared float buf[nth]; @@ -657,6 +657,104 @@ void ggml_vk_soft_max(kp::Sequence& seq, } +static const std::string program_norm = + MULTILINE_QUOTE( +layout(push_constant) uniform PushConstants { + uint64_t ne00; + uint64_t nb01; + float eps; + uint inOff; + uint outOff; +} pcs; + +layout(local_size_x = 1) in; +layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; +layout(binding = 1) buffer restrict tensorOut { float out_[]; }; + +shared float sum[nth]; + +void main() { + const uint x = gl_GlobalInvocationID.x; // Based from in_ + // MEAN + // parallel sum + sum[gl_GlobalInvocationID.y] = 0.0; + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + sum[gl_GlobalInvocationID.y] += in_[x+i00]; + } + // reduce + barrier(); + memoryBarrierShared(); + for (uint i = nth/2; i > 0; i /= 2) { + if (gl_GlobalInvocationID.y < i) { + sum[gl_GlobalInvocationID.y] += sum[gl_GlobalInvocationID.y + i]; + } + barrier(); + memoryBarrierShared(); + } + // broadcast + if (gl_GlobalInvocationID.y == 0) { + sum[0] /= float(pcs.ne00); + } + barrier(); + memoryBarrierShared(); + const float mean = sum[0]; + + // recenter + const uint y = gl_GlobalInvocationID.x; // Based from out_ + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + out_[y+i00] = in_[x+i00] - mean; + } + + // VARIANCE + // parallel sum + sum[gl_GlobalInvocationID.y] = 0.0; + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + sum[gl_GlobalInvocationID.y] += out_[y+i00] * out_[y+i00]; + } + // reduce + barrier(); + memoryBarrierShared(); + for (uint i = nth/2; i > 0; i /= 2) { + if (gl_GlobalInvocationID.y < i) { + sum[gl_GlobalInvocationID.y] += sum[gl_GlobalInvocationID.y + i]; + } + barrier(); + memoryBarrierShared(); + } + // broadcast + if (gl_GlobalInvocationID.y == 0) { + sum[0] /= float(pcs.ne00); + } + barrier(); + memoryBarrierShared(); + const float variance = sum[0]; + + const float scale = 1.0/sqrt(variance + pcs.eps); + for (uint i00 = gl_GlobalInvocationID.y; i00 < pcs.ne00; i00 += nth) { + out_[y+i00] *= scale; + } +} +); + +void ggml_vk_norm(kp::Sequence& seq, + const std::shared_ptr& in, uint32_t inOff, + const std::shared_ptr& out, uint32_t outOff, + int64_t ne00, int64_t ne01, + int64_t nrows) { + const static unsigned nth = 256; + const static auto spirv = glsl_compile_source(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_norm, __func__); + + struct PushConstants { + uint64_t ne00, nb01; + float eps; + uint32_t inOff, outOff; + } pushConsts { + (uint64_t)ne00, (uint64_t)ne01, 1e-5f, inOff, outOff + }; + + seq.record(mgr.algorithm({in, out}, spirv, {(uint32_t)nrows, nth}, {}, {pushConsts})); +} + static const std::string program_diag_mask_inf = MULTILINE_QUOTE( layout(push_constant) uniform PushConstants { @@ -1100,6 +1198,9 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph break; } } + case GGML_OP_NORM: { + ggml_vk_norm(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ggml_nrows(src0)); + } break; default: fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); //GGML_ASSERT(false);