wip
This commit is contained in:
parent
af3eda9c77
commit
6ccbd1777a
2 changed files with 117 additions and 128 deletions
|
@ -2253,13 +2253,15 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nhptg = 4; // heads per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t nsg = 16; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t nqptg = 2; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 8;
|
||||
|
||||
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(nhptg*ne00 + nsg*(256))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(nhptg*ne00 + nsg*(32*ncpsg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue