metal : add comments
ggml-ci
This commit is contained in:
parent
5590160cd6
commit
434fc452c3
2 changed files with 24 additions and 8 deletions
|
@ -2008,8 +2008,10 @@ static void ggml_metal_encode_node(
|
||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = 4;
|
const int ne11_mm_min = 4;
|
||||||
|
|
||||||
|
// first try to use small-batch mat-mv kernels
|
||||||
|
// these should be efficient for batch sizes (BS) in the range [2, ~8]
|
||||||
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
|
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
|
@ -2033,12 +2035,20 @@ static void ggml_metal_encode_node(
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
// TODO: determine the optimal parameters based on grid utilization
|
// TODO: determine the optimal parameters based on grid utilization
|
||||||
const int nsg = 2; // TODO: or 4?
|
// I still don't know why we should not always use the maximum available threads:
|
||||||
const int nxpsg = ne11 < 3 ? 16 : 8;
|
//
|
||||||
const int nypsg = 32/nxpsg;
|
// nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
|
||||||
const int r0ptg = nypsg*nsg;
|
//
|
||||||
int r1ptg = 4;
|
// my current hypothesis is that the work grid is not evenly divisible for different nsg
|
||||||
|
// values and there can be some tail effects when nsg is high. need to confirm this
|
||||||
|
//
|
||||||
|
const int nsg = 2; // num simdgroups per threadgroup
|
||||||
|
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
|
||||||
|
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
|
||||||
|
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
||||||
|
int r1ptg = 4; // num src1 rows per threadgroup
|
||||||
|
|
||||||
|
// note: not sure how optimal these are across all different hardware. there might be something cleverer
|
||||||
switch (ne11) {
|
switch (ne11) {
|
||||||
case 2:
|
case 2:
|
||||||
r1ptg = 2; break;
|
r1ptg = 2; break;
|
||||||
|
|
|
@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32(
|
||||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mat-vec kernel processing in chunks of float4
|
||||||
|
// chpb - chunks per quantization block
|
||||||
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
||||||
void kernel_mul_mv_ext_q4_f32_impl(
|
void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const short chpt = 4;
|
const short chpt = 4; // chunks per thread
|
||||||
|
|
||||||
//const short nxpsg = (32);
|
//const short nxpsg = (32);
|
||||||
const short nypsg = (32/nxpsg);
|
const short nypsg = (32/nxpsg);
|
||||||
|
@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
|
|
||||||
float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
|
float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
|
||||||
|
|
||||||
short cch = tx%chpb;
|
short cch = tx%chpb; // current chunk index
|
||||||
|
|
||||||
for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
|
for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
|
||||||
float4 lx[chpt];
|
float4 lx[chpt];
|
||||||
|
@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reduce only the threads in each row
|
||||||
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||||
if (nxpsg >= 32) {
|
if (nxpsg >= 32) {
|
||||||
sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
|
sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
|
||||||
|
@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mat-vec kernel processing in chunks of float4x4
|
||||||
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
||||||
void kernel_mul_mv_ext_q4x4_f32_impl(
|
void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dispatchers needed for compile-time nxpsg
|
||||||
|
// epb - elements per quantization block
|
||||||
template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
|
template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
|
||||||
kernel void kernel_mul_mv_ext_q4_f32_disp(
|
kernel void kernel_mul_mv_ext_q4_f32_disp(
|
||||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue