From 0899adf86ed9765cc4cc349fb0a980e1bf77dd63 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Mon, 22 Jan 2024 14:16:10 -0500 Subject: [PATCH] kompute : fix get_rows dispatch -> 4 less failures --- ggml-kompute.cpp | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 58c76347e..86bd0d78b 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1089,15 +1089,18 @@ static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_get_rows(const std::vector& spirv, - unsigned element_size, unsigned qk, - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t nb01, int32_t nb1, - uint32_t size) { +static void ggml_vk_get_rows( + const std::vector& spirv, + const char * suffix, + unsigned element_size, unsigned qk, + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t nb01, int32_t nb1, + uint32_t size +) { GGML_ASSERT(nb01%element_size == 0); GGML_ASSERT(nb1%sizeof(float) == 0); if (qk) GGML_ASSERT(ne00%qk == 0); @@ -1110,11 +1113,12 @@ static void ggml_vk_get_rows(const std::vector& spirv, ne00, nb01, nb1 }; + auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) + if (!komputeManager()->hasAlgorithm(name)) { s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(__func__); + } else { + s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({inA, inB, out}); s_algo->setWorkgroup({size}); s_algo->setPushConstants({pushConsts}); @@ -1128,7 +1132,7 @@ static void ggml_vk_get_rows_f16(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv, kp::shader_data::op_getrows_f16_comp_spv_len); - ggml_vk_get_rows(spirv, sizeof(half), 0, std::forward(args)...); + ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward(args)...); } template @@ -1136,7 +1140,7 @@ static void ggml_vk_get_rows_q4_0(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv, kp::shader_data::op_getrows_q4_0_comp_spv_len); - ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_0, std::forward(args)...); + ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward(args)...); } template @@ -1144,14 +1148,14 @@ static void ggml_vk_get_rows_q4_1(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv, kp::shader_data::op_getrows_q4_1_comp_spv_len); - ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_1, std::forward(args)...); + ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward(args)...); } template static void ggml_vk_get_rows_q6_k(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv, kp::shader_data::op_getrows_q6_k_comp_spv_len); - ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK_NL, std::forward(args)...); + ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward(args)...); } static void ggml_vk_rope(