From 964fe8c546dba2e88e13d6f6d09a62c45008ac61 Mon Sep 17 00:00:00 2001
From: niansa
Date: Fri, 30 Jun 2023 11:47:10 +0200
Subject: [PATCH] Added mul_mat (needs fixes)

What this patch adds to ggml-vulkan.cpp:

* program_fpx_to_fpx / ggml_vk_fp32_to_fp16_row: an element-wise
  conversion shader (one invocation per element, local_size_x = 1),
  compiled here with IN_TYPE=float and OUT_TYPE=float16_t.
* program_mul_mat_f16 / program_mul_mat_f32: tiled matrix multiplication.
  A workgroup has (BM*BN)/(TM*TN) = 256 invocations and produces one
  BM x BN (128 x 128) tile of the output. K is walked in slices of BK (8)
  that are staged in shared memory; every invocation accumulates a
  TM x TN (8 x 8) sub-tile in registers and stores it at the end.
* ggml_vk_mul_mat_f16 / ggml_vk_mul_mat_f32: host wrappers that record
  one dispatch per (i02, i03) slice. The f16 variant first converts inB
  from fp32 into a temporary fp16 tensor.
* ggml_vk_graph_compute: routes GGML_OP_MUL_MAT to the wrappers.
* soft_max shader: indexes by gl_WorkGroupID instead of
  gl_LocalInvocationID and marks the first reduction loop [[unroll]]
  (GL_EXT_control_flow_attributes is enabled for that).

Review fixes applied in this revision of the patch:

* Both mul_mat shaders filled cacheA[TM] with a loop bounded by BM,
  writing far past the end of the array. The bound is now TM.
* The f32 shader derived lr/lc/loadr/loadc from gl_WorkGroupID, so every
  invocation of a workgroup used the same indices. It now uses
  gl_LocalInvocationID like the f16 shader.
* Both wrappers dispatched ne01 x ne11 workgroups although one workgroup
  covers a 128 x 128 tile. They now dispatch ceil(ne01/128) x
  ceil(ne11/128), and the shaders zero-fill loads and skip stores that
  fall outside M, N or K so that partial tiles are safe.
* inB_cont_cols compared the row stride nb11 with ne11 * sizeof(float);
  a contiguous row is ne10 elements long.
* The f16 path was selected for src0 F32 / src1 F16, but the wrapper
  reads inA as fp16 and converts inB from fp32. The test is now
  src0 F16 / src1 F32.
* The f16 branch had no break and fell into default, printing
  "not implemented" after doing the work.
* The fp32 accumulators in the f16 shader were initialised with a half
  literal (0.0hf); they now use 0.0f.

Open issues that are NOT addressed here (need information outside this
patch):

* soft_max now indexes the shared buffer with gl_WorkGroupID. Shared
  memory is per workgroup, so this must be checked against the local
  size and dispatch of that shader.
* The f16 shader never reads pcs.inAOff/inBOff/outOff, the f16 wrapper
  does not advance the inA/out offsets per (i02, i03), and the fp32 to
  fp16 staging ignores inBOff.
* The f32 wrapper adds i02*nb2 + i03*nb3 to all three offsets and the
  shader adds them to element indices; the unit of the incoming offsets
  has to be confirmed.
* program_fpx_to_fpx declares a push constant "row" that the host struct
  does not provide and the shader does not read.
* The non-contiguous staging path records one dispatch per element.

NOTE(review): this text was recovered from a copy of the patch in which
line breaks, indentation and C++ template argument lists had been lost.
Indentation and blank lines are reconstructed, so apply with
"git apply --ignore-whitespace". The template arguments
(std::shared_ptr<kp::Tensor>, record<kp::OpAlgoDispatch>,
algorithm<float, PushConstants>, tensorT<half>) are reconstructed as well
and must be confirmed against the repository. The commit id and blob ids
are those of the original patch.

---
 ggml-vulkan.cpp | 391 ++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 377 insertions(+), 14 deletions(-)

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 70247a40d..d6b99aa1f 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -217,6 +217,7 @@ static const std::string program_source_head = R"(#version 450
 #extension GL_EXT_shader_explicit_arithmetic_types_float16: enable
 #extension GL_EXT_shader_explicit_arithmetic_types_int8: enable
 #extension GL_EXT_shader_explicit_arithmetic_types_int64: enable
+#extension GL_EXT_control_flow_attributes: enable
 
 #define QK4_0 32
 #define QR4_0 2
@@ -336,6 +337,44 @@ void ggml_vk_dequantize_row_q4_1(const void *x_, float *y, int k) {
 }
 
 
+static const std::string program_fpx_to_fpx =
+        MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    uint row;
+} pcs;
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer restrict readonly tensorIn { IN_TYPE in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { OUT_TYPE out_[]; };
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+
+    out_[pcs.outOff + i] = OUT_TYPE(in_[pcs.inOff + i]);
+}
+);
+
+void ggml_vk_fp32_to_fp16_row(kp::Sequence& seq,
+                              const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
+                              const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
+                              uint32_t size) {
+    const static auto spirv = glsl_compile_source(program_source_head+
+                                                  "#define IN_TYPE float\n"
+                                                  "#define OUT_TYPE float16_t\n"+
+                                                  program_fpx_to_fpx, __func__);
+
+    struct PushConstants {
+        uint32_t inOff, outOff;
+    } const pushConsts {
+        inOff, outOff
+    };
+
+    seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {size}, {}, {pushConsts}));
+}
+
+
 static const std::string program_abmath =
         MULTILINE_QUOTE(
 layout(push_constant) uniform PushConstants {
@@ -535,24 +574,24 @@ void main() {
     const uint out_off = pcs.outOff + extra_off;
 
     // parallel max
-    buf[gl_LocalInvocationID.x] = uintBitsToFloat(0xFF800000);
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
-        buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], in_[in_off + i00]);
+    buf[gl_WorkGroupID.x] = uintBitsToFloat(0xFF800000);
+    for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
+        buf[gl_WorkGroupID.x] = max(buf[gl_WorkGroupID.x], in_[in_off + i00]);
     }
 
     // reduce
     barrier();
     memoryBarrierShared();
-    for (uint i = nth/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], buf[gl_LocalInvocationID.x + i]);
+    [[unroll]] for (uint i = nth/2; i > 0; i /= 2) {
+        if (gl_WorkGroupID.x < i) {
+            buf[gl_WorkGroupID.x] = max(buf[gl_WorkGroupID.x], buf[gl_WorkGroupID.x + i]);
         }
         barrier();
         memoryBarrierShared();
     }
 
     // broadcast (no effect?)
-    if (gl_LocalInvocationID.x == 0) {
+    if (gl_WorkGroupID.x == 0) {
         buf[0] = buf[0]; // ???
     }
 
@@ -562,24 +601,24 @@ void main() {
     const float max_ = buf[0];
 
     // parallel sum
-    buf[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
-        buf[gl_LocalInvocationID.x] += exp(in_[in_off + i00] - max_);
+    buf[gl_WorkGroupID.x] = 0.0;
+    for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
+        buf[gl_WorkGroupID.x] += exp(in_[in_off + i00] - max_);
     }
 
     // reduce
     barrier();
     memoryBarrierShared();
     for (uint i = nth/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            buf[gl_LocalInvocationID.x] += buf[gl_LocalInvocationID.x + i];
+        if (gl_WorkGroupID.x < i) {
+            buf[gl_WorkGroupID.x] += buf[gl_WorkGroupID.x + i];
         }
         barrier();
         memoryBarrierShared();
     }
 
     // broadcast (no effect?)
-    if (gl_LocalInvocationID.x == 0) {
+    if (gl_WorkGroupID.x == 0) {
         buf[0] = buf[0]; // ???
     }
 
@@ -588,7 +627,7 @@ void main() {
 
     const float sum = buf[0];
 
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
+    for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
         out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum;
     }
 }
@@ -612,6 +651,317 @@ void ggml_vk_soft_max(kp::Sequence& seq,
 }
 
 
+static const std::string program_mul_mat_f16 = R"(
+#define BM 128
+#define BN 128
+#define BK 8
+#define TM 8
+#define TN 8
+)" MULTILINE_QUOTE(
+layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    int M;
+    int N;
+    int K;
+    int inAStride;
+    int inBStride;
+    int outStride;
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+} pcs;
+
+shared float16_t bufA[BM * (BK+1)];
+shared float16_t bufB[BN * (BK+1)];
+
+void main() {
+    const int ir = int(gl_WorkGroupID.x);
+    const int ic = int(gl_WorkGroupID.y);
+
+    const int rstride = BM / TM;
+
+    const int lr = int(gl_LocalInvocationID.x % rstride);
+    const int lc = int(gl_LocalInvocationID.x / rstride);
+
+    const int loadr = int(gl_LocalInvocationID.x % BK);
+    const int loadc = int(gl_LocalInvocationID.x / BK);
+
+    const int loadstride = int(gl_WorkGroupSize.x);
+
+    int posA = ir * BM * pcs.inAStride;
+    int posB = ic * BN * pcs.inBStride;
+
+    float sums[TM * TN];
+    float16_t cacheA[TM];
+    float16_t cacheB[TN];
+
+    [[unroll]] for (int i = 0; i < TM*TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    [[unroll]] for (int block = 0; block < pcs.K; block += BK) {
+        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
+            const int lr = l % BK;
+            const int lc = l / BK;
+            // Zero-fill elements outside the matrix so partial tiles add nothing
+            if (ir * BM + loadc + lc < pcs.M && block + loadr + lr < pcs.K) {
+                bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr];
+            } else {
+                bufA[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
+            }
+        }
+        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
+            const int lr = l % BK;
+            const int lc = l / BK;
+            if (ic * BN + loadc + lc < pcs.N && block + loadr + lr < pcs.K) {
+                bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr];
+            } else {
+                bufB[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
+            }
+        }
+
+        barrier();
+
+        posA += BK;
+        posB += BK;
+
+        [[unroll]] for (int i = 0; i < BK; i++) {
+            // Load from shared into cache
+            [[unroll]] for (int j = 0; j < TM; j++) {
+                cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
+            }
+            [[unroll]] for (int j = 0; j < TN; j++) {
+                cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
+            }
+
+            [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                    sums[cc * TM + cr] += float(cacheA[cr]) * float(cacheB[cc]);
+                }
+            }
+        }
+
+        barrier();
+    }
+
+    const int dr = ir * BM + lr;
+    const int dc = ic * BN + lc * TN;
+
+    [[unroll]] for (int cc = 0; cc < TN; cc++) {
+        [[unroll]] for (int cr = 0; cr < TM; cr++) {
+            // Skip results that belong to the padding of a partial tile
+            if (dr + cr*rstride < pcs.M && dc + cc < pcs.N) {
+                out_[(dc + cc) * pcs.outStride + dr + cr*rstride] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
+);
+
+void ggml_vk_mul_mat_f16(kp::Sequence& seq,
+                         const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
+                         const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
+                         const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
+                         int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
+                         int64_t ne10, int64_t ne11,
+                         int nb10, int nb11, int nb12, int nb13,
+                         int nb2, int nb3) {
+    const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
+
+    const bool inB_cont_rows = nb10 == sizeof(float);
+    const bool inB_cont_cols = (size_t)nb11 == ne10 * sizeof(float); // row stride equals row length
+
+    struct PushConstants {
+        int32_t M, N, K, inAStride, inBStride, outStride;
+        uint32_t inAOff, inBOff, outOff;
+    } pushConsts {
+        (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
+        inAOff, inBOff, outOff
+    };
+
+    // One workgroup per 128 x 128 output tile; 128 must match BM/BN in the shader
+    const uint32_t wgM = (uint32_t)((ne01 + 127) / 128);
+    const uint32_t wgN = (uint32_t)((ne11 + 127) / 128);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            auto tmp = mgr.tensorT<half>(std::vector<half>(ne10*ne11));
+
+            if (inB_cont_rows) {
+                if (inB_cont_cols) {
+                    ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12)/sizeof(float), tmp, 0, ne10*ne11);
+                }
+                else {
+                    for (int64_t i01 = 0; i01 < ne11; i01++) {
+                        ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11)/sizeof(float), tmp, i01*ne10, ne10);
+                    }
+                }
+            } else {
+                for (int64_t i01 = 0; i01 < ne11; i01++) {
+                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                        // Extremely slow because of single shader invocation
+                        ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)/sizeof(float), tmp, i01*ne10 + i00, 1);
+                    }
+                }
+            }
+
+            seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, tmp, out}, spirv, {wgM, wgN}, {}, {pushConsts}));
+        }
+    }
+}
+
+
+static const std::string program_mul_mat_f32 = R"(
+#define BM 128
+#define BN 128
+#define BK 8
+#define TM 8
+#define TN 8
+)" MULTILINE_QUOTE(
+layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    int M;
+    int N;
+    int K;
+    int inAStride;
+    int inBStride;
+    int outStride;
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+} pcs;
+
+shared float bufA[BM * (BK+1)];
+shared float bufB[BN * (BK+1)];
+
+void main() {
+    const int ir = int(gl_WorkGroupID.x);
+    const int ic = int(gl_WorkGroupID.y);
+
+    const int rstride = BM / TM;
+
+    const int lr = int(gl_LocalInvocationID.x % rstride);
+    const int lc = int(gl_LocalInvocationID.x / rstride);
+
+    const int loadr = int(gl_LocalInvocationID.x % BK);
+    const int loadc = int(gl_LocalInvocationID.x / BK);
+
+    const int loadstride = int(gl_WorkGroupSize.x);
+
+    int posA = ir * BM * pcs.inAStride;
+    int posB = ic * BN * pcs.inBStride;
+
+    float sums[TM * TN];
+    float cacheA[TM];
+    float cacheB[TN];
+
+    [[unroll]] for (int i = 0; i < TM*TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    [[unroll]] for (int block = 0; block < pcs.K; block += BK) {
+        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
+            const int lr = l % BK;
+            const int lc = l / BK;
+            // Zero-fill elements outside the matrix so partial tiles add nothing
+            if (ir * BM + loadc + lc < pcs.M && block + loadr + lr < pcs.K) {
+                bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr + pcs.inAOff];
+            } else {
+                bufA[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
+            }
+        }
+        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
+            const int lr = l % BK;
+            const int lc = l / BK;
+            if (ic * BN + loadc + lc < pcs.N && block + loadr + lr < pcs.K) {
+                bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr + pcs.inBOff];
+            } else {
+                bufB[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
+            }
+        }
+
+        barrier();
+        memoryBarrierShared();
+
+        posA += BK;
+        posB += BK;
+
+        [[unroll]] for (int i = 0; i < BK; i++) {
+            // Load from shared into cache
+            [[unroll]] for (int j = 0; j < TM; j++) {
+                cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
+            }
+            [[unroll]] for (int j = 0; j < TN; j++) {
+                cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
+            }
+
+            [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                    sums[cc * TM + cr] += cacheA[cr] * cacheB[cc];
+                }
+            }
+        }
+
+        barrier();
+    }
+
+    const int dr = ir * BM + lr;
+    const int dc = ic * BN + lc * TN;
+
+    [[unroll]] for (int cc = 0; cc < TN; cc++) {
+        [[unroll]] for (int cr = 0; cr < TM; cr++) {
+            // Skip results that belong to the padding of a partial tile
+            if (dr + cr*rstride < pcs.M && dc + cc < pcs.N) {
+                out_[(dc + cc) * pcs.outStride + dr + cr*rstride + pcs.outOff] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
+);
+
+void ggml_vk_mul_mat_f32(kp::Sequence& seq,
+                         const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
+                         const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
+                         const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
+                         int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
+                         int64_t ne10, int64_t ne11,
+                         int nb2, int nb3) {
+    const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f32, __func__);
+
+    struct PushConstants {
+        int32_t M, N, K, inAStride, inBStride, outStride;
+        uint32_t inAOff, inBOff, outOff;
+    } pushConsts {
+        (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
+        inAOff, inBOff, outOff
+    };
+
+    // One workgroup per 128 x 128 output tile; 128 must match BM/BN in the shader
+    const uint32_t wgM = (uint32_t)((ne01 + 127) / 128);
+    const uint32_t wgN = (uint32_t)((ne11 + 127) / 128);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            auto off = i02*nb2 + i03*nb3;
+            pushConsts.inAOff = inAOff + off;
+            pushConsts.inBOff = inBOff + off;
+            pushConsts.outOff = outOff + off;
+            seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {wgM, wgN}, {}, {pushConsts}));
+        }
+    }
+}
+
+
 void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
     printf("%s: evaluating graph\n", __func__);
 
@@ -723,6 +1073,19 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 {
                     ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
                 } break;
+            case GGML_OP_MUL_MAT:
+                {
+                    if (src0->type == GGML_TYPE_F32
+                     && src1->type == GGML_TYPE_F32) {
+                        ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
+                        break;
+                    } else if (src0->type == GGML_TYPE_F16
+                            && src1->type == GGML_TYPE_F32) {
+                        ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
+                        break;
+                    }
+                    // Unsupported type combination: fall through and report it below
+                }
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                 //GGML_ASSERT(false);