Added vk_mul to ggml_vk_graph_compute

This commit is contained in:
niansa 2023-06-23 14:19:31 +02:00
parent 18d6f7f8da
commit b6264542b7

View file

@ -286,6 +286,7 @@ layout(push_constant) uniform PushConstants {
uint inAOff; uint inAOff;
uint inBOff; uint inBOff;
uint outOff; uint outOff;
uint row;
} pcs; } pcs;
@ -298,20 +299,23 @@ layout(binding = 2) buffer tensorout { float out[]; };
void main() { void main() {
const int i = int(gl_GlobalInvocationID.x); const int i = int(gl_GlobalInvocationID.x);
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i]; out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
} }
); );
template<char mathOP> template<char mathOP>
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff, void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff, const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
std::shared_ptr<kp::Tensor>& out, uint32_t outOff) { std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath); uint32_t row = 0) {
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n"
"#define ROW_OP "+(row?"% pcs.row":"")+"\n"
+program_abmath);
struct PushConstants { struct PushConstants {
uint32_t inAOff, inBOff, outOff; uint32_t inAOff, inBOff, outOff, row;
} pushConsts { } pushConsts {
inAOff, inBOff, outOff inAOff, inBOff, outOff, row
}; };
mgr.sequence() mgr.sequence()
@ -334,7 +338,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
const int n_seq = gf->n_threads; const int n_seq = gf->n_threads;
std::vector<kp::Sequence> sequences(n_seq); std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
for (auto& sequence : sequences) {
sequence = mgr.sequence();
}
std::vector<std::thread> threads(n_seq); std::vector<std::thread> threads(n_seq);
@ -346,7 +354,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
size_t offs_src1 = 0; size_t offs_src1 = 0;
size_t offs_dst = 0; size_t offs_dst = 0;
auto& seq = sequences[seq_idx]; auto& seq = *sequences[seq_idx];
const int node_start = (seq_idx + 0) * n_nodes_per_seq; const int node_start = (seq_idx + 0) * n_nodes_per_seq;
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq; const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
@ -408,6 +416,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{ {
ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst); ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
} break; } break;
case GGML_OP_MUL:
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00);
} else {
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
}
} break;
} }
} }
}); });