kompute : fix get_rows dispatch -> 4 fewer failures
This commit is contained in:
parent
cb9ceff966
commit
0899adf86e
1 changed file with 20 additions and 16 deletions
|
@ -1089,15 +1089,18 @@ static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq,
|
|||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_get_rows(const std::vector<uint32_t>& spirv,
|
||||
unsigned element_size, unsigned qk,
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t nb01, int32_t nb1,
|
||||
uint32_t size) {
|
||||
static void ggml_vk_get_rows(
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const char * suffix,
|
||||
unsigned element_size, unsigned qk,
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t nb01, int32_t nb1,
|
||||
uint32_t size
|
||||
) {
|
||||
GGML_ASSERT(nb01%element_size == 0);
|
||||
GGML_ASSERT(nb1%sizeof(float) == 0);
|
||||
if (qk) GGML_ASSERT(ne00%qk == 0);
|
||||
|
@ -1110,11 +1113,12 @@ static void ggml_vk_get_rows(const std::vector<uint32_t>& spirv,
|
|||
ne00, nb01, nb1
|
||||
};
|
||||
|
||||
auto name = std::string(__func__) + "_" + suffix;
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__))
|
||||
if (!komputeManager()->hasAlgorithm(name)) {
|
||||
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
|
||||
else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(name);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
s_algo->setWorkgroup({size});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
|
@ -1128,7 +1132,7 @@ static void ggml_vk_get_rows_f16(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
|
||||
kp::shader_data::op_getrows_f16_comp_spv_len);
|
||||
|
||||
ggml_vk_get_rows(spirv, sizeof(half), 0, std::forward<Args>(args)...);
|
||||
ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -1136,7 +1140,7 @@ static void ggml_vk_get_rows_q4_0(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
|
||||
kp::shader_data::op_getrows_q4_0_comp_spv_len);
|
||||
|
||||
ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
|
||||
ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -1144,14 +1148,14 @@ static void ggml_vk_get_rows_q4_1(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
|
||||
kp::shader_data::op_getrows_q4_1_comp_spv_len);
|
||||
|
||||
ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
|
||||
ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
static void ggml_vk_get_rows_q6_k(Args&&... args) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
|
||||
kp::shader_data::op_getrows_q6_k_comp_spv_len);
|
||||
ggml_vk_get_rows(spirv, 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
|
||||
ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_rope(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue