metal : support Q > 8

This commit is contained in:
Georgi Gerganov 2024-01-28 23:08:31 +02:00
parent 134c81c78d
commit 1db22d7032
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 55 additions and 34 deletions

View file

@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8)
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32)
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)