Fixed ggml_vk_abmath row argument

This commit is contained in:
niansa 2023-06-28 10:15:23 +02:00
parent 072007b1e8
commit ed14f0764a

View file

@ -344,15 +344,17 @@ void main() {
}
);
template<char mathOP>
template<char mathOP, bool with_row = false>
void ggml_vk_abmath(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
uint32_t size, uint32_t row = 0) {
GGML_ASSERT(with_row?row:!row);
const static auto spirv = compileSource(program_source_head+
"#define MATH_OP "+std::string(1, mathOP)+"\n"
"#define ROW_OP "+(row?"% pcs.row":"")+'\n'+
"#define ROW_OP "+(with_row?"% pcs.row":"")+'\n'+
program_abmath, __func__);
struct PushConstants {
@ -369,9 +371,9 @@ void ggml_vk_add(Args&&... args) {
return ggml_vk_abmath<'+'>(std::forward<Args>(args)...);
}
template <typename... Args>
template <bool with_row = false, typename... Args>
void ggml_vk_mul(Args&&... args) {
return ggml_vk_abmath<'*'>(std::forward<Args>(args)...);
return ggml_vk_abmath<'*', with_row>(std::forward<Args>(args)...);
}
@ -589,7 +591,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ggml_nelements(dst));
ggml_vk_mul<true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
} else {
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
}