metal : use F32 attention accumulators

This commit is contained in:
Georgi Gerganov 2024-04-18 20:08:52 +03:00
parent fa9e8c6689
commit c16a7c2688
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 81 additions and 93 deletions

View file

@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ncpsg % 32 == 0);
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ncpsg % 32 == 0);
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
int64_t nsg = 1;
@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
}
nsg /= 2;
// require power of 2
//{
// int64_t nsgm = 1;
// while (nsgm < nsg) {
// nsgm *= 2;
// }
// GGML_ASSERT(nsg == nsgm);
//}
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);