metal : warp-based reduction for soft max kernel

This commit is contained in:
Georgi Gerganov 2023-11-30 21:52:32 +02:00
parent 68e02c0d58
commit 55717c98c4
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 78 additions and 71 deletions

View file

@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
int nth = 32; // SIMD width
if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
do {
while (nth < ne00 && nth < 1024) {
nth *= 2;
} while (nth <= ne00 && nth <= 1024);
nth /= 2;
}
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;