diff --git a/ggml/src/ggml-kompute/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp index 4a0e48990..a8adcfb5b 100644 --- a/ggml/src/ggml-kompute/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute/ggml-kompute.cpp @@ -1018,6 +1018,8 @@ static void ggml_vk_mul_mat_impl( int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0, int32_t ne1, + uint32_t nb01, uint32_t nb02, uint32_t nb03, + uint32_t nb11, uint32_t nb12, uint32_t nb13, uint32_t r2, uint32_t r3 ) { struct PushConstants { @@ -1025,19 +1027,23 @@ static void ggml_vk_mul_mat_impl( int32_t ne00, ne01, ne02; int32_t ne10, ne12; int32_t ne0, ne1; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; uint32_t r2, r3; } pushConsts { safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4), ne00, ne01, ne02, ne10, ne12, ne0, ne1, + nb01, nb02, nb03, + nb11, nb12, nb13, r2, r3 }; auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(name)) { - const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; + const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8; s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(name); @@ -1694,19 +1700,22 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml case GGML_TYPE_Q8_0: ggml_vk_mul_mat_q8_0( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q4_0: ggml_vk_mul_mat_q4_0( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q4_1: ggml_vk_mul_mat_q4_1( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp index 440b5ab2c..a6517cc1f 100644 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp @@ -14,10 +14,15 @@ void main() { const uint i12 = im%pcs.ne12; const uint i13 = im/pcs.ne12; - const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); + // pointers to src0 rows + uint ax[N_ROWS]; + for (int row = 0; row < N_ROWS; ++row) { + const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); - const uint x = offset0; // Based from inA without base offset - const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB + ax[row] = offset0 + pcs.inAOff; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f}; @@ -32,8 +37,7 @@ void main() { for (uint ib = ix; ib < nb; ib += 16) { for (int row = 0; row < N_ROWS; row++) { - const uint block_index = x + ib + row * nb; - sumf[row] += block_q_n_dot_y(block_index, yb, il); + sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il); } yb += BLOCKS_IN_QUANT * 16; diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp index 7912b09ac..a9a2f2218 100644 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp @@ -1,5 +1,5 @@ layout(local_size_x_id = 0) in; -layout(local_size_y = 1) in; +layout(local_size_y = 8) in; layout(local_size_z = 1) in; layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; @@ -17,6 +17,12 @@ layout (push_constant) uniform parameter { int ne12; int ne0; int ne1; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; uint r2; uint r3; } pcs;