metal : efficient flash_attn_f16 implementation

This commit is contained in:
Georgi Gerganov 2024-01-23 18:27:54 +02:00
parent 17720fad66
commit 1446a12b29
No known key found for this signature in database
GPG key ID: BF970631944C16B7
3 changed files with 195 additions and 118 deletions

View file

@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute(
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
GGML_ASSERT(ggml_are_same_shape(src1, src2));
GGML_ASSERT(src3);
size_t offs_src2 = 0;
size_t offs_src3 = 0;
@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nwarps = 32;
const int64_t nhptg = 2; // heads per threadgroup
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps)
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY: