diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index b7434da3e..c247b50c9 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -2008,8 +2008,10 @@ static void ggml_metal_encode_node(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                int ne11_mm_min = 4;
+                const int ne11_mm_min = 4;
 
+                // first try to use small-batch mat-mv kernels
+                // these should be efficient for BS [2, ~8]
                 if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
                     (
                      (
@@ -2033,12 +2035,20 @@ static void ggml_metal_encode_node(
                      )
                     )
                    ) {
                     // TODO: determine the optimal parameters based on grid utilization
-                    const int nsg = 2; // TODO: or 4?
-                    const int nxpsg = ne11 < 3 ? 16 : 8;
-                    const int nypsg = 32/nxpsg;
-                    const int r0ptg = nypsg*nsg;
-                    int r1ptg = 4;
+                    //       I still don't know why we should not always use the maximum available threads:
+                    //
+                    //           nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                    //
+                    //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+                    //       values and there can be some tail effects when nsg is high. need to confirm this
+                    //
+                    const int nsg   = 2;                   // num simdgroups per threadgroup
+                    const int nxpsg = ne11 < 3 ? 16 : 8;   // num threads along row per simdgroup
+                    const int nypsg = 32/nxpsg;            // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                    const int r0ptg = nypsg*nsg;           // num src0 rows per threadgroup
+                    int       r1ptg = 4;                   // num src1 rows per threadgroup
+
+                    // note: not sure how optimal these are across all the different hardware. there might be something cleverer
                     switch (ne11) {
                         case 2:
                             r1ptg = 2; break;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 3d15897f4..7567f3262 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
 void kernel_mul_mv_ext_q4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short chpt = 4;
+    const short chpt = 4; // chunks per thread
 
   //const short nxpsg = (32);
     const short nypsg = (32/nxpsg);
@@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
 
     float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
 
-    short cch = tx%chpb;
+    short cch = tx%chpb; // current chunk index
 
     for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
         float4 lx[chpt];
@@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         }
     }
 
+    // reduce only the threads in each row
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
@@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
     }
 }
 
+// mat-vec kernel processing in chunks of float4x4
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
 void kernel_mul_mv_ext_q4x4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
     }
 }
 
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
 template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
 kernel void kernel_mul_mv_ext_q4_f32_disp(
         constant ggml_metal_kargs_mul_mv_ext & args,
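
The new comment above `nsg` leaves the simdgroup count as an open question. As a rough sanity check of the stated hypothesis (and nothing more), here is a small standalone C++ model of the dispatch grid. It assumes, purely for illustration, a GPU that runs a fixed number of threadgroups concurrently, treats the ne12/ne13 batch dimensions as 1, and uses made-up sizes; none of this code is part of the patch.

```cpp
#include <cstdio>

// One way to read the hypothesis: raising nsg makes each threadgroup cover more
// src0 rows (r0ptg = (32/nxpsg)*nsg), so the dispatch produces fewer
// threadgroups and the partially filled last scheduling "wave" costs
// relatively more. 'cores' and the tensor sizes below are illustrative only.
int main() {
    const int ne01  = 4096; // src0 rows (illustrative)
    const int ne11  = 3;    // src1 rows (batch size)
    const int nxpsg = 16;   // threads along a row per simdgroup
    const int r1ptg = 3;    // src1 rows per threadgroup
    const int cores = 40;   // hypothetical number of threadgroups running concurrently

    for (int nsg = 1; nsg <= 8; nsg *= 2) {
        const int r0ptg = (32/nxpsg)*nsg;              // src0 rows per threadgroup
        const int tgs   = ((ne01 + r0ptg - 1)/r0ptg) *
                          ((ne11 + r1ptg - 1)/r1ptg);  // threadgroups dispatched
        const int waves = (tgs + cores - 1)/cores;     // scheduling waves
        const double ideal = double(tgs)/cores;        // fractional number of waves

        printf("nsg = %d: %4d threadgroups, %2d waves (%.1f ideal) -> ~%.0f%% tail overhead\n",
               nsg, tgs, waves, ideal, 100.0*(waves/ideal - 1.0));
    }
    return 0;
}
```

With these made-up numbers the relative cost of the last, partially filled wave grows as nsg increases, which is at least consistent with the "tail effects when nsg is high" guess; real behaviour would still need profiling.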
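
The `// reduce only the threads in each row` comment describes the guarded `simd_shuffle_down` cascade that follows it: an offset of `o` is only applied when `nxpsg >= 2*o`, so the 32 lanes of a simdgroup reduce as independent rows of `nxpsg` lanes. Below is a hypothetical host-side C++ simulation of that pattern (not the kernel itself) that checks the first lane of each row ends up with exactly its own row's sum.

```cpp
#include <cstdio>

// Toy model of the guarded shuffle-down reduction: 32 lanes split into rows of
// nxpsg lanes; because the largest offset used is nxpsg/2, lane 0 of each row
// only ever accumulates values from its own row.
constexpr int kLanes = 32;

// shuffle_down stand-in: lane i reads lane i+delta; out-of-range lanes keep
// their own value here (on hardware the result is unspecified, but those
// lanes are never consumed by a row's first lane)
static void shuffle_down(const float in[kLanes], int delta, float out[kLanes]) {
    for (int i = 0; i < kLanes; ++i) {
        out[i] = (i + delta < kLanes) ? in[i + delta] : in[i];
    }
}

int main() {
    const int nxpsg = 8; // threads along a row per simdgroup -> 4 rows

    float sumf[kLanes];
    for (int i = 0; i < kLanes; ++i) {
        sumf[i] = float(i); // distinct per-lane partial sums
    }

    // mirrors the kernel's guards: offset o is applied only when nxpsg >= 2*o
    for (int o = 16; o >= 1; o /= 2) {
        if (nxpsg >= 2*o) {
            float t[kLanes];
            shuffle_down(sumf, o, t);
            for (int i = 0; i < kLanes; ++i) {
                sumf[i] += t[i];
            }
        }
    }

    for (int row = 0; row < kLanes/nxpsg; ++row) {
        float expect = 0.0f;
        for (int i = 0; i < nxpsg; ++i) {
            expect += float(row*nxpsg + i);
        }
        printf("row %d: reduced = %.0f, expected = %.0f\n", row, sumf[row*nxpsg], expect);
    }
    return 0;
}
```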
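
The `// dispatchers needed for compile-time nxpsg` comment refers to the usual runtime-to-compile-time switch: the host picks `nxpsg` per dispatch (based on `ne11`), but the impl kernels want it as a template constant so loop bounds, array sizes, and the shuffle guards fold away. A minimal C++ sketch of that pattern follows, with made-up function names rather than the actual kernel signatures.

```cpp
#include <cstdio>

// Templated "impl": nxpsg and r1ptg are compile-time constants here, so the
// compiler can fully unroll and dead-strip the specialized code paths.
template<short nxpsg, short r1ptg>
void mul_mv_impl(int ne00) {
    printf("impl<nxpsg=%d, r1ptg=%d> over ne00=%d\n", nxpsg, r1ptg, ne00);
}

// Thin dispatcher: switches on the runtime value and instantiates the impl
// for each supported constant (illustrative values only).
template<short r1ptg>
void mul_mv_disp(short nxpsg, int ne00) {
    switch (nxpsg) {
        case  8: mul_mv_impl< 8, r1ptg>(ne00); break;
        case 16: mul_mv_impl<16, r1ptg>(ne00); break;
        case 32: mul_mv_impl<32, r1ptg>(ne00); break;
        default: /* unsupported geometry */     break;
    }
}

int main() {
    const int   ne11  = 2;                   // batch size
    const short nxpsg = ne11 < 3 ? 16 : 8;   // host-side choice, as in the patch
    mul_mv_disp<4>(nxpsg, 4096);             // r1ptg baked into the kernel variant
    return 0;
}
```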