metal : warp-based reduction for soft max kernel
This commit is contained in:
parent
68e02c0d58
commit
55717c98c4
2 changed files with 78 additions and 71 deletions
10
ggml-metal.m
10
ggml-metal.m
|
@@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
|
|||
int nth = 32; // SIMD width
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||
} else {
|
||||
do {
|
||||
while (nth < ne00 && nth < 1024) {
|
||||
nth *= 2;
|
||||
} while (nth <= ne00 && nth <= 1024);
|
||||
nth /= 2;
|
||||
}
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||
}
|
||||
|
||||
|
@@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue