metal : warp-based reduce for rms_norm

This commit is contained in:
Georgi Gerganov 2023-11-30 22:21:30 +02:00
parent 55717c98c4
commit c4db59230d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 28 additions and 35 deletions

View file

@ -1358,15 +1358,19 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int nth = MIN(512, ne00);
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < 1024) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);