metal : fix shared memory calc + reduce smem + comments

Georgi Gerganov 2024-11-05 08:13:48 +02:00
parent 1e129611b1
commit d805404e2d
2 changed files with 40 additions and 14 deletions


@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
 GGML_ASSERT(nqptg % 8 == 0);
 GGML_ASSERT(ncpsg % 32 == 0);
+// 16*32*nsgmax
+// the shared memory needed for the simdgroups to load the KV cache
+// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
+//
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 int64_t nsgmax = 2;
 while (true) {
-// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
-// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
-const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
+const size_t smem = FATTN_SMEM(nsgmax);
 if (smem > device.maxThreadgroupMemoryLength) {
 break;
 }
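To get a feel for the numbers, here is a small host-side C++ sketch that evaluates the same expression as the FATTN_SMEM macro above for a few simdgroup counts. The head and tile sizes (ne00 = 128, nqptg = 8, ncpsg = 32) are assumptions picked for illustration, and pad() stands in for GGML_PAD; none of this is ggml API.

    // illustrative only: evaluates the FATTN_SMEM expression from the diff above
    #include <cstdio>
    #include <cstddef>

    static size_t pad(size_t x, size_t n) { return ((x + n - 1)/n)*n; } // stands in for GGML_PAD

    int main() {
        const size_t ne00  = 128; // head size (assumed)
        const size_t nqptg = 8;   // queries processed per threadgroup (assumed)
        const size_t ncpsg = 32;  // cache positions processed per simdgroup (assumed)

        for (size_t nsg = 2; nsg <= 16; nsg *= 2) {
            // query tile + per-simdgroup scratch + 16*32*nsg halves for the dequantized
            // KV tile, times 2 bytes per half (sizeof(float)/2)
            const size_t smem = pad((nqptg*(ne00 + 2*(ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);
            printf("nsg = %2zu -> smem = %zu bytes\n", nsg, smem);
        }
        return 0;
    }

With a typical 32 KB threadgroup-memory limit, these assumed sizes would let the doubling loop above settle on nsgmax = 8; the exact value depends on the device and the head size.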
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
 // simdgroups per threadgroup (a.k.a. warps)
 const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
+const size_t smem = FATTN_SMEM(nsg);
 //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+[encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 } else {
 // half4x4 kernel
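For the simdgroup-matrix branch above, the simdgroup count is then clamped by nsgmax, the KV length, and the pipeline's thread limit. A standalone sketch of that clamp with made-up inputs; the real ne01, ne11 and thread limit come from the tensor shapes and the MTLComputePipelineState.

    // illustrative only: the nsg selection from the diff above, with assumed inputs
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne01 = 1;     // number of queries (assumed: single-token decode)
        const int64_t ne11 = 4096;  // KV cache length (assumed)
        const int64_t nqptg = 8, ncpsg = 32;
        const int64_t nsgmax = 8;   // result of the doubling search above (assumed)
        const int64_t maxThreads = 1024; // pipeline.maxTotalThreadsPerThreadgroup (assumed)

        // few queries: use as many simdgroups as the limits allow (at least 4);
        // many queries: stick with 4 simdgroups per threadgroup
        const int64_t nsg = ne01 <= nqptg
            ? std::max<int64_t>(4, std::min<int64_t>(nsgmax, std::min<int64_t>(ne11/ncpsg, maxThreads/32)))
            : 4;

        printf("nsg = %lld\n", (long long) nsg);
        return 0;
    }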
@@ -3178,8 +3181,28 @@ static void ggml_metal_encode_node(
 GGML_ASSERT(nqptg % 1 == 0);
 GGML_ASSERT(ncpsg % 32 == 0);
+// ne00 + 2*ncpsg*(nsg)
+// for each query, we load it as f16 in shared memory (ne00)
+// and store the attention scores (nqptg x ncpsg) as f32
+//
+// 2*ne00*(nsg)
+// each simdgroup has a full f32 head vector in shared mem to accumulate results
+//
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+int64_t nsgmax = 2;
+while (true) {
+const size_t smem = FATTN_SMEM(nsgmax);
+if (smem > device.maxThreadgroupMemoryLength) {
+break;
+}
+nsgmax *= 2;
+}
+nsgmax /= 2;
 // simdgroups per threadgroup (a.k.a. warps)
-const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 int64_t nsg = 1;
 while (nsg <= nsgt) {
@@ -3187,12 +3210,12 @@ static void ggml_metal_encode_node(
 }
 nsg /= 2;
-const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
+const size_t smem = FATTN_SMEM(nsg);
-//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+[encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 }
 } break;
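The same idea for the vector kernel: the new FATTN_SMEM drops the nqptg term from the per-simdgroup scratch (matching the SH change in the kernel further below), which is where the small shared-memory reduction comes from. A sketch comparing the old and new expressions for assumed sizes (head size 128, nqptg = 1, ncpsg = 32):

    // illustrative only: old vs new shared-memory size for the vector branch
    #include <cstdio>
    #include <cstddef>

    static size_t pad(size_t x, size_t n) { return ((x + n - 1)/n)*n; } // stands in for GGML_PAD

    int main() {
        const size_t ne00 = 128, nqptg = 1, ncpsg = 32; // assumed

        for (size_t nsg = 2; nsg <= 8; nsg *= 2) {
            // old: the per-simdgroup scratch was (ncpsg + nqptg) wide
            const size_t smem_old = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
            // new: only ncpsg scores per simdgroup, plus the f32 result vector (2*ne00 halves)
            const size_t smem_new = pad((nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(sizeof(float)/2), 16);
            printf("nsg = %zu: old = %zu bytes, new = %zu bytes\n", nsg, smem_old, smem_new);
        }
        return 0;
    }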


@@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext(
 for (short cc = 0; cc < C/8; ++cc) {
 simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+// this is compile-time check, so it does not have runtime overhead
 if (is_same<block_q, half4x4>::value) {
 // we can read directly from global memory
 device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -2891,6 +2892,7 @@
 device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 if (D16%4 == 0) {
+// the head is evenly divisible by 4*16 = 64, so no need for bound checks
 half4x4 tmp;
 dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 skv4[4*ty + tx] = tmp;
@@ -3009,6 +3011,7 @@
 device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
 if (D16%4 == 0) {
+// no need for bound checks
 half4x4 tmp;
 dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 skv4[4*ty + tx] = tmp;
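The comments added in the kernel hunks above annotate an existing pattern rather than new logic: is_same<block_q, half4x4>::value is a compile-time constant per template instantiation, so the untaken branch is folded away and the F16 path pays nothing for the quantization check. A minimal C++ analogue of that dispatch; the type names below are stand-ins, not the actual ggml Metal types.

    // illustrative only: compile-time dispatch on the KV block type
    #include <cstdio>
    #include <type_traits>

    struct half4x4    { float v[16]; };          // stand-in for the Metal half4x4 type
    struct block_q4_0 { unsigned char q[18]; };  // stand-in for a quantized block

    template <typename block_q>
    void load_kv_tile() {
        // the condition is a compile-time constant, so the dead branch is
        // eliminated and there is no runtime check
        if (std::is_same<block_q, half4x4>::value) {
            printf("F16 cache: read directly from global memory\n");
        } else {
            printf("quantized cache: dequantize a block into a half4x4 tile first\n");
        }
    }

    int main() {
        load_kv_tile<half4x4>();
        load_kv_tile<block_q4_0>();
        return 0;
    }

The D16 % 4 == 0 comments in the same hunks record the companion assumption: when the head size is a multiple of 64, the four half4x4 loads per thread never need bound checks.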
@@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec(
 const short D16 = D/16;
 const short NW = N_SIMDWIDTH;
 const short NW4 = NW/4;
-const short SH = (C + Q); // shared memory per simdgroup in (half)
+const short SH = C; // shared memory per simdgroup in (half)
 const short T = D + 2*nsg*SH; // shared memory size per query in (half)
 //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
 threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
-threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
 threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
 threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
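Kernel side, SH = C is the counterpart of the new vector-branch FATTN_SMEM: each simdgroup's scratch now holds only the C attention scores, and the offsets above carve up the threadgroup memory accordingly. A host-side sketch of that layout for assumed constants (D = 128, C = 32, Q = 1, 4 simdgroups); offsets are in half (2-byte) units, mirroring the pointer arithmetic above.

    // illustrative only: threadgroup memory layout of the vector kernel after this change
    #include <cstdio>

    int main() {
        const int D = 128, C = 32, Q = 1, nsg = 4; // assumed kernel constants
        const int SH = C;             // per-simdgroup scratch in halves (was C + Q)
        const int T  = D + 2*nsg*SH;  // halves per query: query data + per-simdgroup f32 scores

        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            printf("sgitg %d: ss at half offset %4d, sr44 at half offset %4d\n",
                   sgitg, 2*sgitg*SH + 1*D, 2*sgitg*D + Q*T);
        }

        // Q*T halves of queries + scores, plus 2*D halves of f32 results per simdgroup;
        // in bytes this matches the vector-branch FATTN_SMEM(nsg) before padding
        printf("total: %d bytes\n", 2*(Q*T + 2*nsg*D));
        return 0;
    }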