metal : add parallel reduce version (disabled)

This commit is contained in:
Georgi Gerganov 2024-01-25 17:59:41 +02:00
parent f9ca5dcbe8
commit 6fea843b24
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 42 additions and 2 deletions

View file

@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps)
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)