Added vk_mul to ggml_vk_graph_compute

This commit is contained in:
niansa 2023-06-23 14:19:31 +02:00
parent 18d6f7f8da
commit b6264542b7

View file

@ -286,6 +286,7 @@ layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
uint row;
} pcs;
@ -298,20 +299,23 @@ layout(binding = 2) buffer tensorout { float out[]; };
void main() {
const int i = int(gl_GlobalInvocationID.x);
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i];
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
}
);
template<char mathOP>
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
std::shared_ptr<kp::Tensor>& out, uint32_t outOff) {
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath);
std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
uint32_t row = 0) {
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n"
"#define ROW_OP "+(row?"% pcs.row":"")+"\n"
+program_abmath);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
uint32_t inAOff, inBOff, outOff, row;
} pushConsts {
inAOff, inBOff, outOff
inAOff, inBOff, outOff, row
};
mgr.sequence()
@ -334,7 +338,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
const int n_seq = gf->n_threads;
std::vector<kp::Sequence> sequences(n_seq);
std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
for (auto& sequence : sequences) {
sequence = mgr.sequence();
}
std::vector<std::thread> threads(n_seq);
@ -346,7 +354,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
size_t offs_src1 = 0;
size_t offs_dst = 0;
auto& seq = sequences[seq_idx];
auto& seq = *sequences[seq_idx];
const int node_start = (seq_idx + 0) * n_nodes_per_seq;
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
@ -408,6 +416,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
} break;
case GGML_OP_MUL:
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00);
} else {
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
}
} break;
}
}
});