Optimized ggml_vk_mul_mat_f16 argument count

This commit is contained in:
niansa 2023-07-05 13:34:01 +02:00
parent 6be93e6071
commit 856b7589e9

View file

@ -891,13 +891,8 @@ layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
int64_t ne00;
int64_t ne01;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
int64_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int64_t ne0;
@ -945,21 +940,20 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
int64_t ne00, int64_t ne01,
uint64_t nb00, uint64_t nb01, uint64_t nb02,
int64_t ne10, int64_t ne11, int64_t ne12,
uint64_t nb10, uint64_t nb11, uint64_t nb12,
uint64_t nb01, uint64_t nb02,
int64_t ne11, int64_t ne12,
uint64_t nb11, uint64_t nb12,
int64_t ne0, int64_t ne1) {
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
struct PushConstants {
int64_t ne00, ne01;
uint64_t nb00, nb01, nb02;
int64_t ne10, ne11;
uint64_t nb10, nb11, nb12;
int64_t ne00;
uint64_t nb01, nb02;
uint64_t nb11, nb12;
int64_t ne0, ne1;
uint32_t inAOff, inBOff, outOff;
} pushConsts {
ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1,
ne00, nb01, nb02, nb11, nb12, ne0, ne1,
inAOff, inBOff, outOff
};
@ -1085,7 +1079,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
if (src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32) {
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1);
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1);
break;
}
}