Added vk_mul to ggml_vk_graph_compute
This commit is contained in:
parent
18d6f7f8da
commit
b6264542b7
1 changed file with 24 additions and 7 deletions
|
@ -286,6 +286,7 @@ layout(push_constant) uniform PushConstants {
|
||||||
uint inAOff;
|
uint inAOff;
|
||||||
uint inBOff;
|
uint inBOff;
|
||||||
uint outOff;
|
uint outOff;
|
||||||
|
uint row;
|
||||||
} pcs;
|
} pcs;
|
||||||
|
|
||||||
|
|
||||||
|
@ -298,20 +299,23 @@ layout(binding = 2) buffer tensorout { float out[]; };
|
||||||
void main() {
|
void main() {
|
||||||
const int i = int(gl_GlobalInvocationID.x);
|
const int i = int(gl_GlobalInvocationID.x);
|
||||||
|
|
||||||
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i];
|
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
template<char mathOP>
|
template<char mathOP>
|
||||||
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||||
std::shared_ptr<kp::Tensor>& out, uint32_t outOff) {
|
std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||||
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath);
|
uint32_t row = 0) {
|
||||||
|
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n"
|
||||||
|
"#define ROW_OP "+(row?"% pcs.row":"")+"\n"
|
||||||
|
+program_abmath);
|
||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inAOff, inBOff, outOff;
|
uint32_t inAOff, inBOff, outOff, row;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
inAOff, inBOff, outOff
|
inAOff, inBOff, outOff, row
|
||||||
};
|
};
|
||||||
|
|
||||||
mgr.sequence()
|
mgr.sequence()
|
||||||
|
@ -334,7 +338,11 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
|
|
||||||
const int n_seq = gf->n_threads;
|
const int n_seq = gf->n_threads;
|
||||||
|
|
||||||
std::vector<kp::Sequence> sequences(n_seq);
|
std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
|
||||||
|
|
||||||
|
for (auto& sequence : sequences) {
|
||||||
|
sequence = mgr.sequence();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::thread> threads(n_seq);
|
std::vector<std::thread> threads(n_seq);
|
||||||
|
|
||||||
|
@ -346,7 +354,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
size_t offs_src1 = 0;
|
size_t offs_src1 = 0;
|
||||||
size_t offs_dst = 0;
|
size_t offs_dst = 0;
|
||||||
|
|
||||||
auto& seq = sequences[seq_idx];
|
auto& seq = *sequences[seq_idx];
|
||||||
|
|
||||||
const int node_start = (seq_idx + 0) * n_nodes_per_seq;
|
const int node_start = (seq_idx + 0) * n_nodes_per_seq;
|
||||||
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
|
const int node_end = (seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq;
|
||||||
|
@ -408,6 +416,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
{
|
{
|
||||||
ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
|
ggml_vk_add(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
{
|
||||||
|
if (ggml_nelements(src1) == ne10) {
|
||||||
|
// src1 is a row
|
||||||
|
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00);
|
||||||
|
} else {
|
||||||
|
ggml_vk_mul(id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue