wip : 8 rows per simd group
This commit is contained in:
parent
b97325800a
commit
8cde449b8b
2 changed files with 140 additions and 45 deletions
10
ggml-metal.m
10
ggml-metal.m
|
@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int nwarps = 1;
|
||||
const int64_t nwarps = 2;
|
||||
|
||||
const size_t shalf = sizeof(float)/2;
|
||||
const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2);
|
||||
|
||||
GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0];
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue