diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 35d31157b..517b98135 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -976,6 +976,49 @@ void main() {
 }
 );
 
+// Naive F16 x F32 matrix-multiply kernel: one invocation per output element,
+// out_[col * outStride + row] = sum over i < K of inA[row][i] * inB[col][i].
+// Tensor offsets are applied as element indices, mirroring the f32 kernel
+// that this patch removes. NOTE(review): confirm the host pushes element (not
+// byte) offsets and dispatches enough 32x32 workgroups to cover M x N.
+static const std::string program_fast_mul_mat_f16 =
+    MULTILINE_QUOTE(
+layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; }; // src1 is F32 on this path
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    int M;
+    int N;
+    int K;
+    int inAStride;
+    int inBStride;
+    int outStride;
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+} pcs;
+
+void main() {
+    const int row = int(gl_GlobalInvocationID.x);
+    const int col = int(gl_GlobalInvocationID.y);
+
+    if (row < pcs.M && col < pcs.N) {
+        const int baseA = row * pcs.inAStride + int(pcs.inAOff);
+        const int baseB = col * pcs.inBStride + int(pcs.inBOff);
+        float sum = 0.0f;
+
+        for (int i = 0; i < pcs.K; i++) {
+            sum += float(inA[baseA + i]) * inB[baseB + i];
+        }
+
+        out_[col * pcs.outStride + row + int(pcs.outOff)] = sum;
+    }
+}
+);
+
 void ggml_vk_mul_mat_f16(kp::Sequence& seq,
                          const std::shared_ptr& inA, uint32_t inAOff,
                          const std::shared_ptr& inB, uint32_t inBOff,
@@ -984,7 +1027,7 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
                          int64_t ne10, int64_t ne11,
                          int nb10, int nb11, int nb12, int nb13,
                          int nb2, int nb3) {
-    const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
+    const static auto spirv = glsl_compile_source(program_source_head+program_fast_mul_mat_f16, __func__);
 
     const bool inB_cont_rows = nb10 == sizeof(float);
     const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float);
@@ -1025,131 +1068,6 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
 }
 
 
-static const std::string program_mul_mat_f32 =
-    MULTILINE_QUOTE(
-layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    int M;
-    int N;
-    int K;
-    int inAStride;
-    int inBStride;
-    int outStride;
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-} pcs;
-
-shared float bufA[BM * (BK+1)];
-shared float bufB[BN * (BK+1)];
-
-void main() {
-    const int ir = int(gl_WorkGroupID.x);
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int rstride = BM / TM;
-
-    const int lr = int(gl_WorkGroupID.x % rstride);
-    const int lc = int(gl_WorkGroupID.x / rstride);
-
-    const int loadr = int(gl_WorkGroupID.x % BK);
-    const int loadc = int(gl_WorkGroupID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    int posA = ir * BM * pcs.inAStride;
-    int posB = ic * BN * pcs.inBStride;
-
-    float sums[TM * TN];
-    float cacheA[TM];
-    float cacheB[TN];
-
-    [[unroll]] for (int i = 0; i < TM*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = 0; block < pcs.K; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr + pcs.inAOff];
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr + pcs.inBOff];
-        }
-
-        barrier();
-        memoryBarrierShared();
-
-        posA += BK;
-        posB += BK;
-
-        [[unroll]] for (int i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (int j = 0; j < BM; j++) {
-                cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
-            }
-            [[unroll]] for (int j = 0; j < TN; j++) {
-                cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
-            }
-
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    sums[cc * TM + cr] += cacheA[cr] * cacheB[cc];
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + lr;
-    const int dc = ic * BN + lc * TN;
-
-    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-            out_[(dc + cc) * pcs.outStride + dr + cr*rstride + pcs.outOff] = sums[cc * TM + cr];
-        }
-    }
-}
-);
-
-void ggml_vk_mul_mat_f32(kp::Sequence& seq,
-                         const std::shared_ptr& inA, uint32_t inAOff,
-                         const std::shared_ptr& inB, uint32_t inBOff,
-                         const std::shared_ptr& out, uint32_t outOff,
-                         int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
-                         int64_t ne10, int64_t ne11,
-                         int nb2, int nb3) {
-    const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f32, __func__);
-
-    struct PushConstants {
-        int32_t M, N, K, inAStride, inBStride, outStride;
-        uint32_t inAOff, inBOff, outOff;
-    } pushConsts {
-        (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
-        inAOff, inBOff, outOff
-    };
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            auto off = i02*nb2 + i03*nb3;
-            pushConsts.inAOff = inAOff + off;
-            pushConsts.inBOff = inBOff + off;
-            pushConsts.outOff = outOff + off;
-            seq.record(mgr.algorithm({inA, inB, out}, spirv, {uint32_t(ne01/128), uint32_t(ne11/128)}, {}, {pushConsts}));
-        }
-    }
-}
-
-
 void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
     printf("%s: evaluating graph\n", __func__);
 
@@ -1266,11 +1184,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 } break;
             case GGML_OP_MUL_MAT:
                 {
-                    if (src0->type == GGML_TYPE_F32
-                        && src1->type == GGML_TYPE_F32) {
-                        ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
-                        break;
-                    } else if (src0->type == GGML_TYPE_F16
+                    if (src0->type == GGML_TYPE_F16
                         && src1->type == GGML_TYPE_F32) {
                         ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
                         break;