kompute : fix more dispatch ambiguity -> 12 fewer failures
This commit is contained in:
parent
08e23fd78c
commit
2755ae3d10
1 changed file with 63 additions and 52 deletions
115
ggml-kompute.cpp
115
ggml-kompute.cpp
|
@ -743,22 +743,25 @@ static void ggml_vk_scale(kp::Sequence& seq,
|
|||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_xxlu(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
uint32_t size) {
|
||||
static void ggml_vk_xxlu(
|
||||
const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
uint32_t size
|
||||
) {
|
||||
struct PushConstants {
|
||||
uint32_t inOff, outOff;
|
||||
} const pushConsts {
|
||||
safe_divide(inOff, 4), safe_divide(outOff, 4),
|
||||
};
|
||||
|
||||
auto name = std::string(__func__) + "_" + suffix;
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__))
|
||||
if (!komputeManager()->hasAlgorithm(name)) {
|
||||
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
|
||||
else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(name);
|
||||
s_algo->setTensors({in, out});
|
||||
s_algo->setWorkgroup({size});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
|
@ -772,7 +775,7 @@ static void ggml_vk_silu(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
|
||||
kp::shader_data::op_silu_comp_spv_len);
|
||||
|
||||
ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -780,7 +783,7 @@ static void ggml_vk_relu(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
|
||||
kp::shader_data::op_relu_comp_spv_len);
|
||||
|
||||
ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -788,7 +791,7 @@ static void ggml_vk_gelu(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
|
||||
kp::shader_data::op_gelu_comp_spv_len);
|
||||
|
||||
ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_soft_max(kp::Sequence& seq,
|
||||
|
@ -823,12 +826,14 @@ static void ggml_vk_soft_max(kp::Sequence& seq,
|
|||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_norm_(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t nb01,
|
||||
int32_t nrows, float epsilon) {
|
||||
static void ggml_vk_norm_(
|
||||
const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t nb01,
|
||||
int32_t nrows, float epsilon
|
||||
) {
|
||||
GGML_ASSERT(nb01%sizeof(float) == 0);
|
||||
GGML_ASSERT(ne00%sizeof(float) == 0);
|
||||
|
||||
|
@ -841,11 +846,12 @@ static void ggml_vk_norm_(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
|
|||
(uint32_t)ne00, (uint32_t)nb01, epsilon
|
||||
};
|
||||
|
||||
auto name = std::string(__func__) + "_" + suffix;
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
if (!komputeManager()->hasAlgorithm(name)) {
|
||||
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo = komputeManager()->getAlgorithm(name);
|
||||
s_algo->setTensors({in, out});
|
||||
s_algo->setWorkgroup({(uint32_t)nrows});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
|
@ -859,7 +865,7 @@ static void ggml_vk_norm(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
|
||||
kp::shader_data::op_norm_comp_spv_len);
|
||||
|
||||
ggml_vk_norm_(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -867,7 +873,7 @@ static void ggml_vk_rms_norm(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
|
||||
kp::shader_data::op_rmsnorm_comp_spv_len);
|
||||
|
||||
ggml_vk_norm_(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
|
||||
|
@ -1029,13 +1035,15 @@ static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
|
|||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q4_x(const std::vector<uint32_t>& spirv, uint32_t block_size, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
|
||||
int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02) {
|
||||
static void ggml_vk_mul_mat_q4_x(
|
||||
const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
|
||||
int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
|
||||
) {
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t ne00, ne10, ne0, ne1, ne01, gqa;
|
||||
|
@ -1044,12 +1052,13 @@ static void ggml_vk_mul_mat_q4_x(const std::vector<uint32_t>& spirv, uint32_t bl
|
|||
ne00, ne10, ne0, ne1, ne01, ne12/ne02
|
||||
};
|
||||
|
||||
auto name = std::string(__func__) + "_" + suffix;
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
if (!komputeManager()->hasAlgorithm(name)) {
|
||||
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo = komputeManager()->getAlgorithm(name);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
|
@ -1063,7 +1072,7 @@ static void ggml_vk_mul_mat_q4_0(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
|
||||
kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
|
||||
|
||||
ggml_vk_mul_mat_q4_x(spirv, 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
ggml_vk_mul_mat_q4_x(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -1071,7 +1080,7 @@ static void ggml_vk_mul_mat_q4_1(Args&&... args) {
|
|||
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
|
||||
kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
|
||||
|
||||
ggml_vk_mul_mat_q4_x(spirv, 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
ggml_vk_mul_mat_q4_x(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq,
|
||||
|
@ -1242,16 +1251,18 @@ static void ggml_vk_rope(
|
|||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
template<uint32_t in_element_size, uint32_t out_element_size>
|
||||
static void ggml_vk_cpy(const std::vector<uint32_t>& spirv,
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
|
||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||
int32_t ne0, int32_t ne1, int32_t ne2,
|
||||
uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3) {
|
||||
static void ggml_vk_cpy(
|
||||
const std::vector<uint32_t>& spirv,
|
||||
uint32_t in_element_size, uint32_t out_element_size,
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
|
||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||
int32_t ne0, int32_t ne1, int32_t ne2,
|
||||
uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
|
||||
) {
|
||||
struct PushConstants {
|
||||
uint32_t inOff, outOff;
|
||||
int32_t ne00, ne01, ne02;
|
||||
|
@ -1266,14 +1277,14 @@ static void ggml_vk_cpy(const std::vector<uint32_t>& spirv,
|
|||
nb0, nb1, nb2, nb3
|
||||
};
|
||||
|
||||
static std::string unique_name = std::string(__func__) +
|
||||
"_i_" + std::to_string(in_element_size) +
|
||||
"_o_" + std::to_string(out_element_size);
|
||||
std::string name = std::string(__func__)
|
||||
+ "_i_" + std::to_string(in_element_size)
|
||||
+ "_o_" + std::to_string(out_element_size);
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(unique_name))
|
||||
s_algo = komputeManager()->algorithm<float, PushConstants>(unique_name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
|
||||
if (!komputeManager()->hasAlgorithm(name))
|
||||
s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
|
||||
else {
|
||||
s_algo = komputeManager()->getAlgorithm(unique_name);
|
||||
s_algo = komputeManager()->getAlgorithm(name);
|
||||
s_algo->setTensors({in, out});
|
||||
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
|
@ -1286,28 +1297,28 @@ template <typename... Args>
|
|||
static void ggml_vk_cpy_f32_f16(Args&&... args) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
|
||||
kp::shader_data::op_cpy_f32_f16_comp_spv_len);
|
||||
ggml_vk_cpy<4, 2>(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
static void ggml_vk_cpy_f32_f32(Args&&... args) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
|
||||
kp::shader_data::op_cpy_f32_f32_comp_spv_len);
|
||||
ggml_vk_cpy<4, 4>(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
static void ggml_vk_cpy_f16_f16(Args&&... args) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
|
||||
kp::shader_data::op_cpy_f16_f16_comp_spv_len);
|
||||
ggml_vk_cpy<2, 2>(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
static void ggml_vk_cpy_f16_f32(Args&&... args) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
|
||||
kp::shader_data::op_cpy_f16_f32_comp_spv_len);
|
||||
ggml_vk_cpy<2, 4>(spirv, std::forward<Args>(args)...);
|
||||
ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue