Fixed ggml_vk_abmath row argument

This commit is contained in:
niansa 2023-06-28 10:15:23 +02:00
parent 072007b1e8
commit ed14f0764a

View file

@ -344,15 +344,17 @@ void main() {
} }
); );
template<char mathOP> template<char mathOP, bool with_row = false>
void ggml_vk_abmath(kp::Sequence& seq, void ggml_vk_abmath(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff, const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff, const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff, const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
uint32_t size, uint32_t row = 0) { uint32_t size, uint32_t row = 0) {
GGML_ASSERT(with_row?row:!row);
const static auto spirv = compileSource(program_source_head+ const static auto spirv = compileSource(program_source_head+
"#define MATH_OP "+std::string(1, mathOP)+"\n" "#define MATH_OP "+std::string(1, mathOP)+"\n"
"#define ROW_OP "+(row?"% pcs.row":"")+'\n'+ "#define ROW_OP "+(with_row?"% pcs.row":"")+'\n'+
program_abmath, __func__); program_abmath, __func__);
struct PushConstants { struct PushConstants {
@ -369,9 +371,9 @@ void ggml_vk_add(Args&&... args) {
return ggml_vk_abmath<'+'>(std::forward<Args>(args)...); return ggml_vk_abmath<'+'>(std::forward<Args>(args)...);
} }
template <typename... Args> template <bool with_row = false, typename... Args>
void ggml_vk_mul(Args&&... args) { void ggml_vk_mul(Args&&... args) {
return ggml_vk_abmath<'*'>(std::forward<Args>(args)...); return ggml_vk_abmath<'*', with_row>(std::forward<Args>(args)...);
} }
@ -589,7 +591,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{ {
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ggml_nelements(dst)); ggml_vk_mul<true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
} else { } else {
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst)); ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
} }