From 08e23fd78ca2afdbd0388f66e808851324634428 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Mon, 22 Jan 2024 16:08:16 -0500 Subject: [PATCH] kompute : fix op_mul kernel -> 13 less test failures --- ggml-kompute.cpp | 53 +++++++++++++++++++++++++++---------- kompute-shaders/op_mul.comp | 40 +++++++++++++++++++++++----- 2 files changed, 73 insertions(+), 20 deletions(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 86bd0d78b..76280501f 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -559,7 +559,6 @@ static void ggml_vk_add( int32_t ne0, int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3 ) { - const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv, kp::shader_data::op_add_comp_spv_len); @@ -625,29 +624,47 @@ static void ggml_vk_addrow(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_mul(kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - uint32_t size) { - +static void ggml_vk_mul( + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, + int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03, + int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, + int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13, + int32_t ne0, + int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3 +) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv, kp::shader_data::op_mul_comp_spv_len); struct PushConstants { uint32_t inAOff, inBOff, outOff; + int32_t ne00; + int32_t nb00, nb01, nb02, nb03; + int32_t ne10, ne11, ne12, ne13; + int32_t nb10, nb11, nb12, nb13; + int32_t ne0; + int32_t nb0, nb1, nb2, nb3; } const pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4) + safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), + ne00, + nb00, nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb10, nb11, nb12, nb13, + ne0, + nb0, nb1, nb2, nb3 }; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); - else { + if (!komputeManager()->hasAlgorithm(__func__)) { + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); + } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({size}); + s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); } @@ -1492,7 +1509,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph // src1 is a row ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); } else { - ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); + ggml_vk_mul( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb10, nb11, nb12, nb13, + ne0, + nb0, nb1, nb2, nb3 + ); } } break; case GGML_OP_SCALE: diff --git a/kompute-shaders/op_mul.comp b/kompute-shaders/op_mul.comp index d599460c3..c92647c4d 100644 --- a/kompute-shaders/op_mul.comp +++ b/kompute-shaders/op_mul.comp @@ -2,7 +2,7 @@ #include "common.comp" -layout(local_size_x = 1) in; +layout(local_size_x = 1024) in; layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; @@ -12,13 +12,41 @@ layout(push_constant) uniform PushConstants { uint inAOff; uint inBOff; uint outOff; + int ne00; + int nb00; + int nb01; + int nb02; + int nb03; + int ne10; + int ne11; + int ne12; + int ne13; + int nb10; + int nb11; + int nb12; + int nb13; + int ne0; + int nb0; + int nb1; + int nb2; + int nb3; } pcs; void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; - for (uint x = 0; x < 4; x++) { - const uint i = baseIndex + x; - out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff]; + const uint i13 = i03 % pcs.ne13; + const uint i12 = i02 % pcs.ne12; + const uint i11 = i01 % pcs.ne11; + + uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4); + uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4); + uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4); + + for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) { + const uint i10 = i0 % pcs.ne10; + out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10]; } -} \ No newline at end of file +}