From 6be93e607149e94550c1ba2fa273cdbaa64f2815 Mon Sep 17 00:00:00 2001 From: niansa Date: Wed, 5 Jul 2023 13:28:40 +0200 Subject: [PATCH] Ported mat mul from Metal --- ggml-vulkan.cpp | 188 ++++++++++++------------------------------------ 1 file changed, 47 insertions(+), 141 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 517b98135..5f1b8d43a 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -173,7 +173,7 @@ static std::vector glsl_compile_source(const std::string& source, cons std::ofstream fileOut("tmp_kp_shader.comp"); fileOut << source; fileOut.close(); - if (system(std::string("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null").c_str())) + if (system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null")) throw std::runtime_error("Error running glslangValidator command"); std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary); std::vector buffer; @@ -883,131 +883,59 @@ void ggml_vk_diag_mask_inf(kp::Sequence& seq, static const std::string program_mul_mat_f16 = MULTILINE_QUOTE( -layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 64) in; layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; layout (push_constant) uniform parameter { - int M; - int N; - int K; - int inAStride; - int inBStride; - int outStride; + int64_t ne00; + int64_t ne01; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + int64_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int64_t ne0; + int64_t ne1; uint inAOff; uint inBOff; uint outOff; } pcs; -shared float16_t bufA[BM * (BK+1)]; -shared float16_t bufB[BN * (BK+1)]; +shared float sum[gl_WorkGroupSize.x]; void main() { - const int ir = int(gl_WorkGroupID.x); - const int ic = int(gl_WorkGroupID.y); + const int64_t r0 = gl_GlobalInvocationID.x; + const int64_t r1 = gl_GlobalInvocationID.y; + const int64_t im = gl_GlobalInvocationID.z; - const int rstride = BM / TM; + const uint x = uint((r0*pcs.nb01 + im*pcs.nb02) / 2); // Based from inA + const uint y = uint((r1*pcs.nb11 + im*pcs.nb12) / 4); // based from inB - const int lr = int(gl_LocalInvocationID.x % rstride); - const int lc = int(gl_LocalInvocationID.x / rstride); + sum[gl_WorkGroupID.x] = 0.0f; - const int loadr = int(gl_LocalInvocationID.x % BK); - const int loadc = int(gl_LocalInvocationID.x / BK); - - const int loadstride = int(gl_WorkGroupSize.x); - - int posA = ir * BM * pcs.inAStride; - int posB = ic * BN * pcs.inBStride; - - float sums[TM * TN]; - float16_t cacheA[TM]; - float16_t cacheB[TN]; - - [[unroll]] for (int i = 0; i < TM*TN; i++) { - sums[i] = 0.0hf; + for (uint i = gl_WorkGroupID.x; i < pcs.ne00; i += gl_WorkGroupSize.x) { + sum[gl_WorkGroupID.x] += float(inA[x+i]) * float(inB[y+i]); } - [[unroll]] for (int block = 0; block < pcs.K; block += BK) { - [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { - const int lr = l % BK; - const int lc = l / BK; - bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr]; + // accumulate the sum from all threads in the threadgroup + barrier(); + memoryBarrierShared(); + for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) { + if (gl_WorkGroupID.x < i) { + sum[gl_WorkGroupID.x] += sum[gl_WorkGroupID.x + i]; } - [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) { - const int lr = l % BK; - const int lc = l / BK; - bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr]; - } - - barrier(); - - posA += BK; - posB += BK; - - [[unroll]] for (int i = 0; i < BK; i++) { - // Load from shared into cache - [[unroll]] for (int j = 0; j < BM; j++) { - cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i]; - } - [[unroll]] for (int j = 0; j < TN; j++) { - cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i]; - } - - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - sums[cc * TM + cr] += float(cacheA[cr]) * float(cacheB[cc]); - } - } - } - barrier(); + memoryBarrierShared(); } - const int dr = ir * BM + lr; - const int dc = ic * BN + lc * TN; - - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - out_[(dc + cc) * pcs.outStride + dr + cr*rstride] = sums[cc * TM + cr]; - } - } -} -); - -static const std::string program_fast_mul_mat_f16 = - MULTILINE_QUOTE( -layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - int M; - int N; - int K; - int inAStride; - int inBStride; - int outStride; - uint inAOff; - uint inBOff; - uint outOff; -} pcs; - -void main() { - int row = int(gl_GlobalInvocationID.x); - int col = int(gl_GlobalInvocationID.y); - - if (row < pcs.M && col < pcs.N) { - float sum = 0.0f; - - for (int i = 0; i < pcs.K; i++) { - sum += float(inA[row * pcs.inAStride + i]) * float(inB[col * pcs.inBStride + i]); - } - - out_[col * pcs.outStride + row] = sum; + if (gl_WorkGroupID.x == 0) { + out_[uint(im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0)] = sum[0]; } } ); @@ -1016,48 +944,26 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq, const std::shared_ptr& inA, uint32_t inAOff, const std::shared_ptr& inB, uint32_t inBOff, const std::shared_ptr& out, uint32_t outOff, - int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03, - int64_t ne10, int64_t ne11, - int nb10, int nb11, int nb12, int nb13, - int nb2, int nb3) { - const static auto spirv = glsl_compile_source(program_source_head+program_fast_mul_mat_f16, __func__); - - const bool inB_cont_rows = nb10 == sizeof(float); - const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float); + int64_t ne00, int64_t ne01, + uint64_t nb00, uint64_t nb01, uint64_t nb02, + int64_t ne10, int64_t ne11, int64_t ne12, + uint64_t nb10, uint64_t nb11, uint64_t nb12, + int64_t ne0, int64_t ne1) { + const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__); struct PushConstants { - int32_t M, N, K, inAStride, inBStride, outStride; + int64_t ne00, ne01; + uint64_t nb00, nb01, nb02; + int64_t ne10, ne11; + uint64_t nb10, nb11, nb12; + int64_t ne0, ne1; uint32_t inAOff, inBOff, outOff; } pushConsts { - (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01, + ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1, inAOff, inBOff, outOff }; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - auto tmp = mgr.tensorT(std::vector(ne10*ne11)); - - if (inB_cont_rows) { - if (inB_cont_cols) { - ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12)/sizeof(float), tmp, 0, ne10*ne11); - } - else { - for (int64_t i01 = 0; i01 < ne11; i01++) { - ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11)/sizeof(float), tmp, i01*ne10, ne10); - } - } - } else { - for (int64_t i01 = 0; i01 < ne11; i01++) { - for (int64_t i00 = 0; i00 < ne10; i00++) { - // Extremely slow because of single shader invocation - ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)/sizeof(float), tmp, i01*ne10 + i00, 1); - } - } - } - - seq.record(mgr.algorithm({inA, tmp, out}, spirv, {uint32_t(ne01/128), uint32_t(ne11/128)}, {}, {pushConsts})); - } - } + seq.record(mgr.algorithm({inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne11), unsigned(ne12)}, {}, {pushConsts})); } @@ -1179,7 +1085,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3); + ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1); break; } }