metal : use F32 attention accumulators
commit c16a7c2688
parent fa9e8c6689
3 changed files with 81 additions and 93 deletions
ggml-metal.m | 15
@@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ncpsg % 32 == 0);

                 // simdgroups per threadgroup (a.k.a. warps)
                 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
                 const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;

-                const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+                const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

                 //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
                 GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
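The functional change in this hunk is the threadgroup memory size: the per-simdgroup scratch term nsg*(ncpsg + nqptg) is doubled so that the F32 attention accumulators fit. A minimal sketch of the arithmetic, using assumed illustrative sizes (head size ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 4; these values are not taken from the commit):

// Sketch only, not part of the commit: evaluate the threadgroup memory
// formula before and after the change, using the assumed sizes above.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne00  = 128; // head dimension
    const int64_t nqptg = 8;   // queries processed per threadgroup
    const int64_t ncpsg = 32;  // cache values processed per simdgroup
    const int64_t nsg   = 4;   // simdgroups per threadgroup

    // sizeof(float)/2 == 2 bytes == sizeof(half): the formula counts half-sized elements
    const size_t smem_old = nqptg*(ne00 +   nsg*(ncpsg + nqptg))*(sizeof(float)/2);
    const size_t smem_new = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

    printf("smem before: %zu bytes, after: %zu bytes\n", smem_old, smem_new); // 4608 vs 7168
    return 0;
}

With these assumed sizes the per-simdgroup scratch grows from 160 to 320 half-sized slots per query row, which is what makes room for the wider accumulators; both totals stay well under the 32 KB checked by GGML_ASSERT against maxThreadgroupMemoryLength.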
@@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ncpsg % 32 == 0);

                 // simdgroups per threadgroup (a.k.a. warps)
                 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
                 const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));

                 int64_t nsg = 1;
@@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 }
                 nsg /= 2;

-                // require power of 2
-                //{
-                //    int64_t nsgm = 1;
-                //    while (nsgm < nsg) {
-                //        nsgm *= 2;
-                //    }
-                //    GGML_ASSERT(nsg == nsgm);
-                //}
-
-                const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

                 //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
                 GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
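The second kernel path (hunks at 2656 and 2665) follows the same pattern: pick the number of simdgroups, then size threadgroup memory with the doubled per-simdgroup scratch plus nsg*ne00 half-sized slots. A sketch under assumptions, since the loop between "int64_t nsg = 1;" and "nsg /= 2;" is not shown in the diff; the doubling loop below and all sizes are guesses for illustration, not the commit's code:

// Sketch only: assumed reconstruction of the simdgroup selection plus the
// new threadgroup memory formula from the second pair of hunks.
#include <stdint.h>
#include <stdio.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int64_t ne00  = 128;   // head dimension
    const int64_t ne11  = 4096;  // K/V sequence length (example)
    const int64_t nqptg = 1;     // assumed: one query per threadgroup on this path
    const int64_t ncpsg = 32;    // cache values processed per simdgroup
    const int64_t max_threads = 1024; // assumed maxTotalThreadsPerThreadgroup

    const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, max_threads/32));

    // assumed reconstruction of the elided loop: largest power of two <= nsgt
    int64_t nsg = 1;
    while (nsg <= nsgt) {
        nsg *= 2;
    }
    nsg /= 2;

    // new formula: 2*nsg*(ncpsg + nqptg) is the doubled scratch for the F32
    // accumulators; nsg*ne00 presumably holds the per-simdgroup partial outputs
    const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

    printf("nsg = %lld, smem = %zu bytes\n", (long long) nsg, smem); // nsg = 32, smem = 12672
    return 0;
}

The hunk header counts (-2665,16 +2663,7) are consistent with the old smem line, the commented-out power-of-two check, and one blank line being deleted while only the new smem line is added, which is how the diff above is reconstructed.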