metal : parallel reduce across heads

This commit is contained in:
Georgi Gerganov 2024-01-21 22:44:41 +02:00
parent 77d08f3272
commit 17720fad66
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 22 additions and 14 deletions

View file

@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nwarps = 16;
const int64_t nhptg = 4; // heads per threadgroup
const int64_t nwarps = 32;
const int64_t nhptg = 2; // heads per threadgroup
const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);