metal : fix shared memory calc + reduce smem + comments
This commit is contained in:
parent
1e129611b1
commit
d805404e2d
2 changed files with 40 additions and 14 deletions
|
@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
|
||||||
GGML_ASSERT(nqptg % 8 == 0);
|
GGML_ASSERT(nqptg % 8 == 0);
|
||||||
GGML_ASSERT(ncpsg % 32 == 0);
|
GGML_ASSERT(ncpsg % 32 == 0);
|
||||||
|
|
||||||
|
// 16*32*nsgmax
|
||||||
|
// the shared memory needed for the simdgroups to load the KV cache
|
||||||
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
|
||||||
|
//
|
||||||
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||||
|
|
||||||
int64_t nsgmax = 2;
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
|
const size_t smem = FATTN_SMEM(nsgmax);
|
||||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
|
|
||||||
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
|
|
||||||
if (smem > device.maxThreadgroupMemoryLength) {
|
if (smem > device.maxThreadgroupMemoryLength) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||||
|
|
||||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
|
const size_t smem = FATTN_SMEM(nsg);
|
||||||
|
|
||||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||||
|
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
#undef FATTN_SMEM
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
} else {
|
} else {
|
||||||
// half4x4 kernel
|
// half4x4 kernel
|
||||||
|
@ -3178,8 +3181,28 @@ static void ggml_metal_encode_node(
|
||||||
GGML_ASSERT(nqptg % 1 == 0);
|
GGML_ASSERT(nqptg % 1 == 0);
|
||||||
GGML_ASSERT(ncpsg % 32 == 0);
|
GGML_ASSERT(ncpsg % 32 == 0);
|
||||||
|
|
||||||
|
// ne00 + 2*ncpsg*(nsg)
|
||||||
|
// for each query, we load it as f16 in shared memory (ne00)
|
||||||
|
// and store the attention scores (nqptg x ncpsg) as f32
|
||||||
|
//
|
||||||
|
// 2*ne00*(nsg)
|
||||||
|
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||||
|
//
|
||||||
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
||||||
|
|
||||||
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const size_t smem = FATTN_SMEM(nsgmax);
|
||||||
|
if (smem > device.maxThreadgroupMemoryLength) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nsgmax *= 2;
|
||||||
|
}
|
||||||
|
nsgmax /= 2;
|
||||||
|
|
||||||
// simdgroups per threadgroup (a.k.a. warps)
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||||
|
|
||||||
int64_t nsg = 1;
|
int64_t nsg = 1;
|
||||||
while (nsg <= nsgt) {
|
while (nsg <= nsgt) {
|
||||||
|
@ -3187,12 +3210,12 @@ static void ggml_metal_encode_node(
|
||||||
}
|
}
|
||||||
nsg /= 2;
|
nsg /= 2;
|
||||||
|
|
||||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
|
const size_t smem = FATTN_SMEM(nsg);
|
||||||
|
|
||||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||||
|
#undef FATTN_SMEM
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
|
@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
||||||
|
|
||||||
|
// this is compile-time check, so it does not have runtime overhead
|
||||||
if (is_same<block_q, half4x4>::value) {
|
if (is_same<block_q, half4x4>::value) {
|
||||||
// we can read directly from global memory
|
// we can read directly from global memory
|
||||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
@ -2891,6 +2892,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
|
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||||
half4x4 tmp;
|
half4x4 tmp;
|
||||||
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
skv4[4*ty + tx] = tmp;
|
||||||
|
@ -3009,6 +3011,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
if (D16%4 == 0) {
|
if (D16%4 == 0) {
|
||||||
|
// no need for bound checks
|
||||||
half4x4 tmp;
|
half4x4 tmp;
|
||||||
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
|
||||||
skv4[4*ty + tx] = tmp;
|
skv4[4*ty + tx] = tmp;
|
||||||
|
@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||||
const short D16 = D/16;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short NW4 = NW/4;
|
const short NW4 = NW/4;
|
||||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
const short SH = C; // shared memory per simdgroup in (half)
|
||||||
|
|
||||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||||
|
|
||||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||||
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
|
||||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
|
||||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||||
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue