Fixed ggml_vk_abmath row argument
This commit is contained in:
parent
072007b1e8
commit
ed14f0764a
1 changed files with 7 additions and 5 deletions
|
@ -344,15 +344,17 @@ void main() {
|
|||
}
|
||||
);
|
||||
|
||||
template<char mathOP>
|
||||
template<char mathOP, bool with_row = false>
|
||||
void ggml_vk_abmath(kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
uint32_t size, uint32_t row = 0) {
|
||||
GGML_ASSERT(with_row?row:!row);
|
||||
|
||||
const static auto spirv = compileSource(program_source_head+
|
||||
"#define MATH_OP "+std::string(1, mathOP)+"\n"
|
||||
"#define ROW_OP "+(row?"% pcs.row":"")+'\n'+
|
||||
"#define ROW_OP "+(with_row?"% pcs.row":"")+'\n'+
|
||||
program_abmath, __func__);
|
||||
|
||||
struct PushConstants {
|
||||
|
@ -369,9 +371,9 @@ void ggml_vk_add(Args&&... args) {
|
|||
return ggml_vk_abmath<'+'>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
template <bool with_row = false, typename... Args>
|
||||
void ggml_vk_mul(Args&&... args) {
|
||||
return ggml_vk_abmath<'*'>(std::forward<Args>(args)...);
|
||||
return ggml_vk_abmath<'*', with_row>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
@ -589,7 +591,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
{
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ggml_nelements(dst));
|
||||
ggml_vk_mul<true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
|
||||
} else {
|
||||
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue