metal : warp-based reduce for rms_norm
This commit is contained in:
parent
55717c98c4
commit
c4db59230d
2 changed files with 28 additions and 35 deletions
18
ggml-metal.m
18
ggml-metal.m
|
@ -1358,15 +1358,19 @@ void ggml_metal_graph_compute(
|
|||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const int nth = MIN(512, ne00);
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00/4 && nth < 1024) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue