Optimized ggml_vk_mul_mat_f16 argument count
This commit is contained in:
parent
6be93e6071
commit
856b7589e9
1 changed files with 8 additions and 14 deletions
|
@ -891,13 +891,8 @@ layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
|||
|
||||
layout (push_constant) uniform parameter {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int64_t ne10;
|
||||
int64_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int64_t ne0;
|
||||
|
@ -945,21 +940,20 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
|||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
int64_t ne00, int64_t ne01,
|
||||
uint64_t nb00, uint64_t nb01, uint64_t nb02,
|
||||
int64_t ne10, int64_t ne11, int64_t ne12,
|
||||
uint64_t nb10, uint64_t nb11, uint64_t nb12,
|
||||
uint64_t nb01, uint64_t nb02,
|
||||
int64_t ne11, int64_t ne12,
|
||||
uint64_t nb11, uint64_t nb12,
|
||||
int64_t ne0, int64_t ne1) {
|
||||
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
|
||||
|
||||
struct PushConstants {
|
||||
int64_t ne00, ne01;
|
||||
uint64_t nb00, nb01, nb02;
|
||||
int64_t ne10, ne11;
|
||||
uint64_t nb10, nb11, nb12;
|
||||
int64_t ne00;
|
||||
uint64_t nb01, nb02;
|
||||
uint64_t nb11, nb12;
|
||||
int64_t ne0, ne1;
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
} pushConsts {
|
||||
ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1,
|
||||
ne00, nb01, nb02, nb11, nb12, ne0, ne1,
|
||||
inAOff, inBOff, outOff
|
||||
};
|
||||
|
||||
|
@ -1085,7 +1079,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
{
|
||||
if (src0->type == GGML_TYPE_F16
|
||||
&& src1->type == GGML_TYPE_F32) {
|
||||
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1);
|
||||
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue