diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 3e7fe30a6..57e1ebf6f 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -286,6 +286,7 @@ layout(push_constant) uniform PushConstants { uint inAOff; uint inBOff; uint outOff; + uint row; } pcs; @@ -298,20 +299,23 @@ layout(binding = 2) buffer tensorout { float out[]; }; void main() { const int i = int(gl_GlobalInvocationID.x); - out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i]; + out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)]; } ); template void ggml_vk_abmath(const std::shared_ptr& inA, uint32_t inAOff, const std::shared_ptr& inB, uint32_t inBOff, - std::shared_ptr& out, uint32_t outOff) { - const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath); + std::shared_ptr& out, uint32_t outOff, + uint32_t row = 0) { + const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n" + "#define ROW_OP "+(row?"% pcs.row":"")+"\n" + +program_abmath); struct PushConstants { - uint32_t inAOff, inBOff, outOff; + uint32_t inAOff, inBOff, outOff, row; } pushConsts { - inAOff, inBOff, outOff + inAOff, inBOff, outOff, row }; mgr.sequence() @@ -334,7 +338,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph const int n_seq = gf->n_threads; - std::vector sequences(n_seq); + std::vector> sequences(n_seq); + + for (auto& sequence : sequences) { + sequence = mgr.sequence(); + } std::vector threads(n_seq); @@ -346,7 +354,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph size_t offs_src1 = 0; size_t offs_dst = 0; - auto& seq = sequences[seq_idx]; + auto& seq = *sequences[seq_idx]; const int node_start = (seq_idx + 0) * n_nodes_per_seq; const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq; @@ -408,6 +416,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); } break; + case GGML_OP_MUL: + { + if (ggml_nelements(src1) == ne10) { + // src1 is a row + ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00); + } else { + ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); + } + } break; } } });