From ed14f0764ad94d5550016226e3c14c1c6e87ce35 Mon Sep 17 00:00:00 2001 From: niansa Date: Wed, 28 Jun 2023 10:15:23 +0200 Subject: [PATCH] Fixed ggml_vk_abmath row argument --- ggml-vulkan.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index b5b2dc5fc..15433d544 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -344,15 +344,17 @@ void main() { } ); -template +template void ggml_vk_abmath(kp::Sequence& seq, const std::shared_ptr& inA, uint32_t inAOff, const std::shared_ptr& inB, uint32_t inBOff, const std::shared_ptr& out, uint32_t outOff, uint32_t size, uint32_t row = 0) { + GGML_ASSERT(with_row?row:!row); + const static auto spirv = compileSource(program_source_head+ "#define MATH_OP "+std::string(1, mathOP)+"\n" - "#define ROW_OP "+(row?"% pcs.row":"")+'\n'+ + "#define ROW_OP "+(with_row?"% pcs.row":"")+'\n'+ program_abmath, __func__); struct PushConstants { @@ -369,9 +371,9 @@ void ggml_vk_add(Args&&... args) { return ggml_vk_abmath<'+'>(std::forward(args)...); } -template +template void ggml_vk_mul(Args&&... args) { - return ggml_vk_abmath<'*'>(std::forward(args)...); + return ggml_vk_abmath<'*', with_row>(std::forward(args)...); } @@ -589,7 +591,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { if (ggml_nelements(src1) == ne10) { // src1 is a row - ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ggml_nelements(dst)); + ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00); } else { ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); }