diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index af697b221..45502ab5a 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -287,13 +287,11 @@ layout(push_constant) uniform PushConstants { uint row; } pcs; - layout(local_size_x = 1) in; layout(binding = 0) buffer tensorInA { float inA[]; }; layout(binding = 1) buffer tensorInB { float inB[]; }; layout(binding = 2) buffer tensorOut { float out_[]; }; - void main() { const int i = int(gl_GlobalInvocationID.x); @@ -302,7 +300,8 @@ void main() { ); template -void ggml_vk_abmath(const std::shared_ptr& inA, uint32_t inAOff, +void ggml_vk_abmath(kp::Sequence& seq, + const std::shared_ptr& inA, uint32_t inAOff, const std::shared_ptr& inB, uint32_t inBOff, const std::shared_ptr& out, uint32_t outOff, uint32_t row = 0) { @@ -317,8 +316,7 @@ void ggml_vk_abmath(const std::shared_ptr& inA, uint32_t inAOff, inAOff, inBOff, outOff, row }; - mgr.sequence() - ->eval(mgr.algorithm({inA, inB, out}, spirv, {std::min(inA->size()-inAOff, inB->size()-inBOff)}, {}, {pushConsts})); + seq.record(mgr.algorithm({inA, inB, out}, spirv, {std::min(inA->size()-inAOff, inB->size()-inBOff)}, {}, {pushConsts})); } template @@ -332,6 +330,42 @@ void ggml_vk_mul(Args&&... args) { } +static const std::string program_scale = + MULTILINE_QUOTE( +layout(push_constant) uniform PushConstants { + uint inAOff; + uint inOff; + float scale; +} pcs; + +layout(local_size_x = 1) in; +layout(binding = 0) buffer tensorInA { float in_[]; }; +layout(binding = 1) buffer tensorOut { float out_[]; }; + +void main() { + const int i = int(gl_GlobalInvocationID.x); + + out_[pcs.outOff+i] = in_[pcs.inOff+i] * pcs.scale; +} +); + +void ggml_vk_scale(kp::Sequence& seq, + const std::shared_ptr& in, uint32_t inOff, + const std::shared_ptr& out, uint32_t outOff, + float scale) { + const static auto spirv = compileSource(program_source_head+program_scale); + + struct PushConstants { + uint32_t inOff, outOff; + float scale; + } pushConsts { + inOff, outOff, scale + }; + + seq.record(mgr.algorithm({in, out}, spirv, {in->size()-inOff}, {}, {pushConsts})); +} + + void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) { printf("%s: evaluating graph\n", __func__); @@ -413,15 +447,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_ADD: { - ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); + ggml_vk_add(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); } break; case GGML_OP_MUL: { if (ggml_nelements(src1) == ne10) { // src1 is a row - ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00); + ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00); } else { - ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); + ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); } } break; }