Added vk_mul to ggml_vk_graph_compute
This commit is contained in:
parent
18d6f7f8da
commit
b6264542b7
1 changed files with 24 additions and 7 deletions
|
@ -286,6 +286,7 @@ layout(push_constant) uniform PushConstants {
|
|||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
uint row;
|
||||
} pcs;
|
||||
|
||||
|
||||
|
@ -298,20 +299,23 @@ layout(binding = 2) buffer tensorout { float out[]; };
|
|||
void main() {
|
||||
const int i = int(gl_GlobalInvocationID.x);
|
||||
|
||||
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i];
|
||||
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
|
||||
}
|
||||
);
|
||||
|
||||
template<char mathOP>
|
||||
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
std::shared_ptr<kp::Tensor>& out, uint32_t outOff) {
|
||||
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath);
|
||||
std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
uint32_t row = 0) {
|
||||
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n"
|
||||
"#define ROW_OP "+(row?"% pcs.row":"")+"\n"
|
||||
+program_abmath);
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
uint32_t inAOff, inBOff, outOff, row;
|
||||
} pushConsts {
|
||||
inAOff, inBOff, outOff
|
||||
inAOff, inBOff, outOff, row
|
||||
};
|
||||
|
||||
mgr.sequence()
|
||||
|
@ -334,7 +338,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
|
||||
const int n_seq = gf->n_threads;
|
||||
|
||||
std::vector<kp::Sequence> sequences(n_seq);
|
||||
std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
|
||||
|
||||
for (auto& sequence : sequences) {
|
||||
sequence = mgr.sequence();
|
||||
}
|
||||
|
||||
std::vector<std::thread> threads(n_seq);
|
||||
|
||||
|
@ -346,7 +354,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
size_t offs_src1 = 0;
|
||||
size_t offs_dst = 0;
|
||||
|
||||
auto& seq = sequences[seq_idx];
|
||||
auto& seq = *sequences[seq_idx];
|
||||
|
||||
const int node_start = (seq_idx + 0) * n_nodes_per_seq;
|
||||
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
|
||||
|
@ -408,6 +416,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
{
|
||||
ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00);
|
||||
} else {
|
||||
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
|
||||
}
|
||||
} break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue