kompute : fix get_rows dispatch -> 4 less failures

This commit is contained in:
Jared Van Bortel 2024-01-22 14:16:10 -05:00
parent cb9ceff966
commit 0899adf86e

View file

@@ -1089,7 +1089,9 @@ static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq,
 seq.record<kp::OpAlgoDispatch>(s_algo);
 }
-static void ggml_vk_get_rows(const std::vector<uint32_t>& spirv,
+static void ggml_vk_get_rows(
+const std::vector<uint32_t>& spirv,
+const char * suffix,
 unsigned element_size, unsigned qk,
 kp::Sequence& seq,
 const std::shared_ptr<kp::Tensor>& inA,
@@ -1097,7 +1099,8 @@ static void ggml_vk_get_rows(const std::vector<uint32_t>& spirv,
 const std::shared_ptr<kp::Tensor>& out,
 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 int32_t ne00, int32_t nb01, int32_t nb1,
-uint32_t size) {
+uint32_t size
+) {
 GGML_ASSERT(nb01%element_size == 0);
 GGML_ASSERT(nb1%sizeof(float) == 0);
 if (qk) GGML_ASSERT(ne00%qk == 0);
@@ -1110,11 +1113,12 @@ static void ggml_vk_get_rows(const std::vector<uint32_t>& spirv,
 ne00, nb01, nb1
 };
+auto name = std::string(__func__) + "_" + suffix;
 std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-if (!komputeManager()->hasAlgorithm(__func__))
+if (!komputeManager()->hasAlgorithm(name)) {
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
-else {
-s_algo = komputeManager()->getAlgorithm(__func__);
+} else {
+s_algo = komputeManager()->getAlgorithm(name);
 s_algo->setTensors({inA, inB, out});
 s_algo->setWorkgroup({size});
 s_algo->setPushConstants<PushConstants>({pushConsts});
@@ -1128,7 +1132,7 @@ static void ggml_vk_get_rows_f16(Args&&... args) {
 const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
 kp::shader_data::op_getrows_f16_comp_spv_len);
-ggml_vk_get_rows(spirv, sizeof(half), 0, std::forward<Args>(args)...);
+ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
 }
 template <typename... Args>
@@ -1136,7 +1140,7 @@ static void ggml_vk_get_rows_q4_0(Args&&... args) {
 const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
 kp::shader_data::op_getrows_q4_0_comp_spv_len);
-ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
+ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
 }
 template <typename... Args>
@@ -1144,14 +1148,14 @@ static void ggml_vk_get_rows_q4_1(Args&&... args) {
 const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
 kp::shader_data::op_getrows_q4_1_comp_spv_len);
-ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
+ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
 }
 template <typename... Args>
 static void ggml_vk_get_rows_q6_k(Args&&... args) {
 const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
 kp::shader_data::op_getrows_q6_k_comp_spv_len);
-ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
+ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
 }
 static void ggml_vk_rope(