metal : add comments

ggml-ci
Georgi Gerganov, 2024-12-03 11:39:33 +02:00
commit 434fc452c3 (parent 5590160cd6)
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)
2 changed files with 24 additions and 8 deletions

File 1 of 2:

@@ -2008,8 +2008,10 @@ static void ggml_metal_encode_node(
         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
         // to the matrix-vector kernel
-        int ne11_mm_min = 4;
+        const int ne11_mm_min = 4;

+        // first try to use small-batch mat-mv kernels
+        // these should be efficient for BS [2, ~8]
         if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
             (
              (
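The two comments added here encode a selection order that is easier to see spelled out. The C sketch below is illustrative only and not part of this commit: the per-quantization-type conditions of the real if-statement are collapsed into a single simplified predicate, the helper name is hypothetical, and the 2..8 batch range is taken from the comment.

    /* rough sketch of the kernel-selection order implied by the comments above */
    enum mul_mat_kernel { MUL_MV_EXT, MUL_MM, MUL_MV };

    enum mul_mat_kernel pick_mul_mat_kernel(int ne00, int ne11, int src1_is_f32) {
        const int ne11_mm_min = 4;                 // break-even point from this hunk

        // first: the small-batch mat-mv ("ext") kernels, efficient for roughly 2..8 src1 rows
        if (src1_is_f32 && ne00 % 256 == 0 && ne11 >= 2 && ne11 <= 8) {
            return MUL_MV_EXT;
        }
        // past the break-even point the matrix-matrix kernel is more efficient
        if (ne11 >= ne11_mm_min) {
            return MUL_MM;
        }
        // otherwise use the plain matrix-vector kernel
        return MUL_MV;
    }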
@@ -2033,12 +2035,20 @@ static void ggml_metal_encode_node(
              )
             ) {
             // TODO: determine the optimal parameters based on grid utilization
-            const int nsg = 2; // TODO: or 4?
-            const int nxpsg = ne11 < 3 ? 16 : 8;
-            const int nypsg = 32/nxpsg;
-            const int r0ptg = nypsg*nsg;
-            int r1ptg = 4;
+            // I still don't know why we should not always use the maximum available threads:
+            //
+            //   nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+            //
+            // my current hypothesis is that the work grid is not evenly divisible for different nsg
+            // values and there can be some tail effects when nsg is high. need to confirm this
+            //
+            const int nsg   = 2;                    // num simdgroups per threadgroup
+            const int nxpsg = ne11 < 3 ? 16 : 8;    // num threads along row per simdgroup
+            const int nypsg = 32/nxpsg;             // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+            const int r0ptg = nypsg*nsg;            // num src0 rows per threadgroup
+            int       r1ptg = 4;                    // num src1 rows per threadgroup
+            // note: not sure how optimal these are across all different hardware. there might be something cleverer
             switch (ne11) {
                 case 2:
                     r1ptg = 2; break;
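To make the new parameter comments concrete, here is a small worked example in C. It is illustrative only: ne01 = 4096 is an assumed number of src0 rows, and the final tile-count formula is an assumption about how a threadgroup grid would be derived from r0ptg and r1ptg, not a quote of the dispatch code.

    #include <stdio.h>

    int main(void) {
        const int ne01 = 4096;                     // src0 rows (assumed)
        const int ne11 = 4;                        // src1 rows, i.e. the batch size

        const int nsg   = 2;                       // simdgroups per threadgroup
        const int nxpsg = ne11 < 3 ? 16 : 8;       // threads along a row per simdgroup -> 8
        const int nypsg = 32/nxpsg;                // src0 rows per simdgroup            -> 4
        const int r0ptg = nypsg*nsg;               // src0 rows per threadgroup          -> 8
        const int r1ptg = 4;                       // src1 rows per threadgroup (ne11 == 4 case)

        // each threadgroup covers an r0ptg x r1ptg tile of the result, so roughly:
        const int tg_x = (ne01 + r0ptg - 1)/r0ptg; // -> 512 tiles along the src0 rows
        const int tg_y = (ne11 + r1ptg - 1)/r1ptg; // -> 1 tile along the src1 rows

        printf("threads per threadgroup: %d, tiles: %d x %d\n", 32*nsg, tg_x, tg_y);
        return 0;
    }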

File 2 of 2:

@@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }

+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
 void kernel_mul_mv_ext_q4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short chpt = 4;
+    const short chpt = 4;    // chunks per thread

   //const short nxpsg = (32);
     const short nypsg = (32/nxpsg);
@@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
     float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };

-    short cch = tx%chpb;
+    short cch = tx%chpb; // current chunk index

     for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
         float4 lx[chpt];
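The "current chunk index" comment refers to a position inside a quantization block: the row of ne00 floats is consumed as float4 chunks, each thread starts at chunk tx and strides by chpt*nxpsg chunks, and cch = tx%chpb is where that first chunk falls inside its block. The C trace below is illustrative only; the ne00, nxpsg, chpt and chpb values are assumptions chosen to keep the numbers small.

    #include <stdio.h>

    int main(void) {
        const int ne00  = 256;   // row length in floats (assumed)
        const int nxpsg = 8;     // threads along the row per simdgroup
        const int chpt  = 4;     // chunks handled per thread per outer iteration
        const int chpb  = 8;     // chunks per quantization block (32 elements / 4)

        const int tx = 3;        // trace a single thread

        // same striding as the kernel's outer loop; in this simplified model the
        // within-block position of an iteration's first chunk is ich % chpb,
        // which is exactly what cch starts out as (tx % chpb)
        for (int ich = tx; 4*ich < ne00; ich += chpt*nxpsg) {
            printf("thread %d: chunk %2d -> block %d, chunk-in-block %d\n",
                   tx, ich, ich/chpb, ich % chpb);
        }
        return 0;
    }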
@@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         }
     }

+    // reduce only the threads in each row
     for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
         if (nxpsg >= 32) {
             sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
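The new comment marks a per-row reduction: the nxpsg lanes that share one output row fold their partial sums with simd_shuffle_down, halving the active width each step, so the first lane of each row ends up holding the row's full dot product. The plain-C model below is illustrative only; it emulates the shuffle with an array standing in for the lanes and assumes nxpsg = 8 (the ne11 >= 3 case).

    #include <stdio.h>

    #define NXPSG 8                       // assumed threads per row

    int main(void) {
        float sumf[NXPSG];
        for (int tx = 0; tx < NXPSG; ++tx) {
            sumf[tx] = 1.0f + tx;         // fake per-lane partial sums: 1..8
        }

        // emulate the shuffle-down tree: lane tx adds the value held by lane tx+offset
        for (int offset = NXPSG/2; offset >= 1; offset /= 2) {
            for (int tx = 0; tx + offset < NXPSG; ++tx) {
                sumf[tx] += sumf[tx + offset];
            }
        }

        printf("row sum in lane 0: %.1f\n", sumf[0]);   // 36.0 == 1+2+...+8
        return 0;
    }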
@@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
     }
 }

+// mat-vec kernel processing in chunks of float4x4
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
 void kernel_mul_mv_ext_q4x4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
     }
 }

+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
 template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
 kernel void kernel_mul_mv_ext_q4_f32_disp(
         constant ggml_metal_kargs_mul_mv_ext & args,
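The dispatcher comments introduce epb (elements per quantization block) alongside the impl's chpb (chunks per quantization block). Since the float4 path consumes 4 floats per chunk and the float4x4 path 16, the two counts relate as chpb = epb/4 and chpb = epb/16 respectively; treating that division as the way a dispatcher would derive chpb is an assumption for illustration. The block sizes below are the standard ggml ones (QK4_0 = QK8_0 = 32, QK_K = 256).

    #include <stdio.h>

    int main(void) {
        // elements per quantization block (standard ggml block sizes)
        const int epb_q4_0 = 32;
        const int epb_q8_0 = 32;
        const int epb_q4_K = 256;

        // float4 path: 4 floats per chunk; float4x4 path: 16 floats per chunk
        printf("Q4_0: chpb = %3d (float4), %2d (float4x4)\n", epb_q4_0/4, epb_q4_0/16);
        printf("Q8_0: chpb = %3d (float4), %2d (float4x4)\n", epb_q8_0/4, epb_q8_0/16);
        printf("Q4_K: chpb = %3d (float4), %2d (float4x4)\n", epb_q4_K/4, epb_q4_K/16);
        return 0;
    }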