diff --git a/CMakeLists.txt b/CMakeLists.txt index c3533b969..0b9b4f023 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -464,6 +464,7 @@ if (LLAMA_KOMPUTE) DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source} ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp + ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} COMMENT "Compiling ${source} to ${spv_file}" diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 45a579b3b..8a9e415e1 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1003,32 +1003,40 @@ static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq, seq.record(s_algo); } -static void ggml_vk_mul_mat_q4_x( +static void ggml_vk_mul_mat_impl( const std::vector& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq, const std::shared_ptr& inA, const std::shared_ptr& inB, const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, - int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02 + int32_t ne00, int32_t ne01, int32_t ne02, + int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, + int32_t ne0, int32_t ne1, + uint32_t r2, uint32_t r3 ) { struct PushConstants { uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne10, ne0, ne1, ne01, gqa; + int32_t ne00, ne01, ne02; + int32_t ne10, ne12; + int32_t ne0, ne1; + uint32_t r2, r3; } pushConsts { safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne10, ne0, ne1, ne01, ne12/ne02 + ne00, ne01, ne02, + ne10, ne12, + ne0, ne1, + r2, r3 }; auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(name)) { const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts}); + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(name); s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12)}); + s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); } @@ -1040,7 +1048,7 @@ static void ggml_vk_mul_mat_q4_0(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv, kp::shader_data::op_mul_mat_q4_0_comp_spv_len); - ggml_vk_mul_mat_q4_x(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward(args)...); + ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward(args)...); } template @@ -1048,16 +1056,18 @@ static void ggml_vk_mul_mat_q4_1(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv, kp::shader_data::op_mul_mat_q4_1_comp_spv_len); - ggml_vk_mul_mat_q4_x(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward(args)...); + ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward(args)...); } -static void ggml_vk_mul_mat_q6_k(kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, - int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02) { +static void ggml_vk_mul_mat_q6_k( + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, + int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02 +) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv, kp::shader_data::op_mul_mat_q6_k_comp_spv_len); @@ -1550,6 +1560,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_MUL_MAT: { + GGML_ASSERT(ne00 == ne10); + + // TODO: assert that dim2 and dim3 are contiguous + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint32_t r2 = ne12/ne02; + const uint32_t r3 = ne13/ne03; + if (src1t != GGML_TYPE_F32) { fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); goto not_implemented; @@ -1563,29 +1582,40 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph switch (src0t) { case GGML_TYPE_F32: - ggml_vk_mul_mat_mat_f32(seq, - id_src0, id_src1, id_dst, - off_src0, off_src1, off_dst, - ne00, ne01, ne02, - nb01, nb02, - ne11, ne12, - nb11, nb12, - nb1, nb2); + ggml_vk_mul_mat_mat_f32( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2 + ); break; case GGML_TYPE_F16: - ggml_vk_mul_mat_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1); + ggml_vk_mul_mat_f16( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1 + ); break; case GGML_TYPE_Q8_0: - ggml_vk_mul_mat_q8_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1); + ggml_vk_mul_mat_q8_0( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, nb01, nb02, ne11, ne12, nb11, nb12, ne0, ne1 + ); break; case GGML_TYPE_Q4_0: - ggml_vk_mul_mat_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02); + ggml_vk_mul_mat_q4_0( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ); break; case GGML_TYPE_Q4_1: - ggml_vk_mul_mat_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02); + ggml_vk_mul_mat_q4_1( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ); break; case GGML_TYPE_Q6_K: - ggml_vk_mul_mat_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02); + ggml_vk_mul_mat_q6_k( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02 + ); break; default: { fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); diff --git a/kompute-shaders/op_mul_mat_q4_0.comp b/kompute-shaders/op_mul_mat_q4_0.comp index 03788c920..b0cea8bbe 100644 --- a/kompute-shaders/op_mul_mat_q4_0.comp +++ b/kompute-shaders/op_mul_mat_q4_0.comp @@ -6,25 +6,7 @@ #define SIZE_OF_BLOCK sizeof_block_q4_0 #define N_ROWS 4 -layout(local_size_x_id = 0) in; -layout(local_size_y = 1) in; -layout(local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne10; - int ne0; - int ne1; - int ne01; - int gqa; -} pcs; +#include "op_mul_mv_q_n_pre.comp" // The q4_0 version of this function float block_q_n_dot_y(uint block_index, uint yb, uint il) { diff --git a/kompute-shaders/op_mul_mat_q4_1.comp b/kompute-shaders/op_mul_mat_q4_1.comp index 0ae8f8c7d..8582c61a3 100644 --- a/kompute-shaders/op_mul_mat_q4_1.comp +++ b/kompute-shaders/op_mul_mat_q4_1.comp @@ -6,25 +6,7 @@ #define SIZE_OF_BLOCK sizeof_block_q4_1 #define N_ROWS 4 -layout(local_size_x_id = 0) in; -layout(local_size_y = 1) in; -layout(local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne10; - int ne0; - int ne1; - int ne01; - int gqa; -} pcs; +#include "op_mul_mv_q_n_pre.comp" // The q4_1 version of this function float block_q_n_dot_y(uint block_index, uint yb, uint il) { diff --git a/kompute-shaders/op_mul_mv_q_n.comp b/kompute-shaders/op_mul_mv_q_n.comp index 8b6e6a2e2..440b5ab2c 100644 --- a/kompute-shaders/op_mul_mv_q_n.comp +++ b/kompute-shaders/op_mul_mv_q_n.comp @@ -1,13 +1,20 @@ void main() { + // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 if (gl_SubgroupInvocationID > 31) return; const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); + const uint r0 = gl_WorkGroupID.x; const uint r1 = gl_WorkGroupID.y; const uint im = gl_WorkGroupID.z; + const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS; - const uint offset0 = first_row * nb + im/pcs.gqa*(nb*pcs.ne0); + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); const uint x = offset0; // Based from inA without base offset const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB diff --git a/kompute-shaders/op_mul_mv_q_n_pre.comp b/kompute-shaders/op_mul_mv_q_n_pre.comp new file mode 100644 index 000000000..7912b09ac --- /dev/null +++ b/kompute-shaders/op_mul_mv_q_n_pre.comp @@ -0,0 +1,22 @@ +layout(local_size_x_id = 0) in; +layout(local_size_y = 1) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + int ne10; + int ne12; + int ne0; + int ne1; + uint r2; + uint r3; +} pcs;