From 856b7589e9661507ff256b401f93c95da3173f2e Mon Sep 17 00:00:00 2001 From: niansa Date: Wed, 5 Jul 2023 13:34:01 +0200 Subject: [PATCH] Optimized ggml_vk_mul_mat_f16 argument count --- ggml-vulkan.cpp | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 5f1b8d43a..6aab3ddae 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -891,13 +891,8 @@ layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; layout (push_constant) uniform parameter { int64_t ne00; - int64_t ne01; - uint64_t nb00; uint64_t nb01; uint64_t nb02; - int64_t ne10; - int64_t ne11; - uint64_t nb10; uint64_t nb11; uint64_t nb12; int64_t ne0; @@ -945,21 +940,20 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq, const std::shared_ptr& inB, uint32_t inBOff, const std::shared_ptr& out, uint32_t outOff, int64_t ne00, int64_t ne01, - uint64_t nb00, uint64_t nb01, uint64_t nb02, - int64_t ne10, int64_t ne11, int64_t ne12, - uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb01, uint64_t nb02, + int64_t ne11, int64_t ne12, + uint64_t nb11, uint64_t nb12, int64_t ne0, int64_t ne1) { const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__); struct PushConstants { - int64_t ne00, ne01; - uint64_t nb00, nb01, nb02; - int64_t ne10, ne11; - uint64_t nb10, nb11, nb12; + int64_t ne00; + uint64_t nb01, nb02; + uint64_t nb11, nb12; int64_t ne0, ne1; uint32_t inAOff, inBOff, outOff; } pushConsts { - ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1, + ne00, nb01, nb02, nb11, nb12, ne0, ne1, inAOff, inBOff, outOff }; @@ -1085,7 +1079,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1); + ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1); break; } }