metal : fix shared memory calc + reduce smem + comments
This commit is contained in:
parent
1e129611b1
commit
d805404e2d
2 changed files with 40 additions and 14 deletions
|
@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
|
|||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// 16*32*nsgmax
|
||||
// the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
while (true) {
|
||||
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
break;
|
||||
}
|
||||
|
@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
|
|||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
// half4x4 kernel
|
||||
|
@ -3178,8 +3181,28 @@ static void ggml_metal_encode_node(
|
|||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// ne00 + 2*ncpsg*(nsg)
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the attention scores (nqptg x ncpsg) as f32
|
||||
//
|
||||
// 2*ne00*(nsg)
|
||||
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
}
|
||||
nsgmax /= 2;
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
|
@ -3187,12 +3210,12 @@ static void ggml_metal_encode_node(
|
|||
}
|
||||
nsg /= 2;
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
|
|
|
@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext(
|
|||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
||||
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<block_q, half4x4>::value) {
|
||||
// we can read directly from global memory
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
@ -2891,6 +2892,7 @@ kernel void kernel_flash_attn_ext(
|
|||
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
half4x4 tmp;
|
||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
|
@ -3009,6 +3011,7 @@ kernel void kernel_flash_attn_ext(
|
|||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
// no need for bound checks
|
||||
half4x4 tmp;
|
||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||
skv4[4*ty + tx] = tmp;
|
||||
|
@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short NW4 = NW/4;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
const short SH = C; // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue