diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 76280501f..163b0a29a 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -743,22 +743,25 @@ static void ggml_vk_scale(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_xxlu(const std::vector& spirv, kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - uint32_t size) { +static void ggml_vk_xxlu( + const std::vector& spirv, const char * suffix, kp::Sequence& seq, + const std::shared_ptr& in, + const std::shared_ptr& out, + uint32_t inOff, uint32_t outOff, + uint32_t size +) { struct PushConstants { uint32_t inOff, outOff; } const pushConsts { safe_divide(inOff, 4), safe_divide(outOff, 4), }; + auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) + if (!komputeManager()->hasAlgorithm(name)) { s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(__func__); + } else { + s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({in, out}); s_algo->setWorkgroup({size}); s_algo->setPushConstants({pushConsts}); @@ -772,7 +775,7 @@ static void ggml_vk_silu(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv, kp::shader_data::op_silu_comp_spv_len); - ggml_vk_xxlu(spirv, std::forward(args)...); + ggml_vk_xxlu(spirv, "silu", std::forward(args)...); } template @@ -780,7 +783,7 @@ static void ggml_vk_relu(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv, kp::shader_data::op_relu_comp_spv_len); - ggml_vk_xxlu(spirv, std::forward(args)...); + ggml_vk_xxlu(spirv, "relu", std::forward(args)...); } template @@ -788,7 +791,7 @@ static void ggml_vk_gelu(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv, kp::shader_data::op_gelu_comp_spv_len); - ggml_vk_xxlu(spirv, std::forward(args)...); + ggml_vk_xxlu(spirv, "gelu", std::forward(args)...); } static void ggml_vk_soft_max(kp::Sequence& seq, @@ -823,12 +826,14 @@ static void ggml_vk_soft_max(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_norm_(const std::vector& spirv, kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - int32_t ne00, int32_t nb01, - int32_t nrows, float epsilon) { +static void ggml_vk_norm_( + const std::vector& spirv, const char * suffix, kp::Sequence& seq, + const std::shared_ptr& in, + const std::shared_ptr& out, + uint32_t inOff, uint32_t outOff, + int32_t ne00, int32_t nb01, + int32_t nrows, float epsilon +) { GGML_ASSERT(nb01%sizeof(float) == 0); GGML_ASSERT(ne00%sizeof(float) == 0); @@ -841,11 +846,12 @@ static void ggml_vk_norm_(const std::vector& spirv, kp::Sequence& seq, (uint32_t)ne00, (uint32_t)nb01, epsilon }; + auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { + if (!komputeManager()->hasAlgorithm(name)) { s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts}); } else { - s_algo = komputeManager()->getAlgorithm(__func__); + s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({in, out}); s_algo->setWorkgroup({(uint32_t)nrows}); s_algo->setPushConstants({pushConsts}); @@ -859,7 +865,7 @@ static void ggml_vk_norm(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv, kp::shader_data::op_norm_comp_spv_len); - ggml_vk_norm_(spirv, std::forward(args)...); + ggml_vk_norm_(spirv, "norm", std::forward(args)...); } template @@ -867,7 +873,7 @@ static void ggml_vk_rms_norm(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv, kp::shader_data::op_rmsnorm_comp_spv_len); - ggml_vk_norm_(spirv, std::forward(args)...); + ggml_vk_norm_(spirv, "rms", std::forward(args)...); } static void ggml_vk_diag_mask_inf(kp::Sequence& seq, @@ -1029,13 +1035,15 @@ static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_mul_mat_q4_x(const std::vector& spirv, uint32_t block_size, kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, - int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02) { +static void ggml_vk_mul_mat_q4_x( + const std::vector& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, + int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02 +) { struct PushConstants { uint32_t inAOff, inBOff, outOff; int32_t ne00, ne10, ne0, ne1, ne01, gqa; @@ -1044,12 +1052,13 @@ static void ggml_vk_mul_mat_q4_x(const std::vector& spirv, uint32_t bl ne00, ne10, ne0, ne1, ne01, ne12/ne02 }; + auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { + if (!komputeManager()->hasAlgorithm(name)) { const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts}); } else { - s_algo = komputeManager()->getAlgorithm(__func__); + s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({inA, inB, out}); s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)}); s_algo->setPushConstants({pushConsts}); @@ -1063,7 +1072,7 @@ static void ggml_vk_mul_mat_q4_0(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv, kp::shader_data::op_mul_mat_q4_0_comp_spv_len); - ggml_vk_mul_mat_q4_x(spirv, 1/*We access blocks unaligned*/, std::forward(args)...); + ggml_vk_mul_mat_q4_x(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward(args)...); } template @@ -1071,7 +1080,7 @@ static void ggml_vk_mul_mat_q4_1(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv, kp::shader_data::op_mul_mat_q4_1_comp_spv_len); - ggml_vk_mul_mat_q4_x(spirv, 1/*We access blocks unaligned*/, std::forward(args)...); + ggml_vk_mul_mat_q4_x(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward(args)...); } static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq, @@ -1242,16 +1251,18 @@ static void ggml_vk_rope( seq.record(s_algo); } -template -static void ggml_vk_cpy(const std::vector& spirv, - kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, - uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, - int32_t ne0, int32_t ne1, int32_t ne2, - uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3) { +static void ggml_vk_cpy( + const std::vector& spirv, + uint32_t in_element_size, uint32_t out_element_size, + kp::Sequence& seq, + const std::shared_ptr& in, + const std::shared_ptr& out, + uint32_t inOff, uint32_t outOff, + int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, + uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, + int32_t ne0, int32_t ne1, int32_t ne2, + uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3 +) { struct PushConstants { uint32_t inOff, outOff; int32_t ne00, ne01, ne02; @@ -1266,14 +1277,14 @@ static void ggml_vk_cpy(const std::vector& spirv, nb0, nb1, nb2, nb3 }; - static std::string unique_name = std::string(__func__) + - "_i_" + std::to_string(in_element_size) + - "_o_" + std::to_string(out_element_size); + std::string name = std::string(__func__) + + "_i_" + std::to_string(in_element_size) + + "_o_" + std::to_string(out_element_size); std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(unique_name)) - s_algo = komputeManager()->algorithm(unique_name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); + if (!komputeManager()->hasAlgorithm(name)) + s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); else { - s_algo = komputeManager()->getAlgorithm(unique_name); + s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({in, out}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); @@ -1286,28 +1297,28 @@ template static void ggml_vk_cpy_f32_f16(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv, kp::shader_data::op_cpy_f32_f16_comp_spv_len); - ggml_vk_cpy<4, 2>(spirv, std::forward(args)...); + ggml_vk_cpy(spirv, 4, 2, std::forward(args)...); } template static void ggml_vk_cpy_f32_f32(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv, kp::shader_data::op_cpy_f32_f32_comp_spv_len); - ggml_vk_cpy<4, 4>(spirv, std::forward(args)...); + ggml_vk_cpy(spirv, 4, 4, std::forward(args)...); } template static void ggml_vk_cpy_f16_f16(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv, kp::shader_data::op_cpy_f16_f16_comp_spv_len); - ggml_vk_cpy<2, 2>(spirv, std::forward(args)...); + ggml_vk_cpy(spirv, 2, 2, std::forward(args)...); } template static void ggml_vk_cpy_f16_f32(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv, kp::shader_data::op_cpy_f16_f32_comp_spv_len); - ggml_vk_cpy<2, 4>(spirv, std::forward(args)...); + ggml_vk_cpy(spirv, 2, 4, std::forward(args)...); } static bool ggml_vk_supports_op(const struct ggml_tensor * op) {