metal : parallelize across KV size

This commit is contained in:
Georgi Gerganov 2024-01-21 21:04:15 +02:00
parent a4b6341c7b
commit 77d08f3272
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 52 additions and 93 deletions

View file

@@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute(
 [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
 [encoder setBytes:&scale length:sizeof( float) atIndex:27];
-const int64_t nwarps = 8;
-const int64_t nhpw = 4; // heads per warp
+const int64_t nwarps = 16;
+const int64_t nhptg = 4; // heads per threadgroup
-const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2);
+const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
 GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
 [encoder setThreadgroupMemoryLength:smem atIndex:0];
-[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY: