From b088e14a7e03104d9e4c027c34b4b7b8b37a124c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Jun 2023 19:26:58 +0300 Subject: [PATCH] mtl : more threads for rms_norm + better timing --- examples/mtl/mtl.cpp | 13 ++++++++++++- examples/mtl/mtl.m | 41 +++++++++++++++++++++-------------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/examples/mtl/mtl.cpp b/examples/mtl/mtl.cpp index 7f52453d8..b7b84cecf 100644 --- a/examples/mtl/mtl.cpp +++ b/examples/mtl/mtl.cpp @@ -46,10 +46,21 @@ int main(int argc, char ** argv) { const std::vector tmp(n_batch, 1); // BOS + // warmup + llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past); + + const int n_iter = 16; + + const int64_t t0 = ggml_time_us(); + // the actual inference happens here - for (int i = 0; i < 10; ++i) { + for (int i = 0; i < n_iter; ++i) { llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past); } + + const int64_t t1 = ggml_time_us(); + + printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter); } llama_mtl_free(ctx_mtl); diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index c74c28cd9..2eb874884 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -492,9 +492,9 @@ int llama_mtl_eval( const uint64_t nb11 = gf->nodes[i]->src1->nb[1]; const uint64_t nb12 = gf->nodes[i]->src1->nb[2]; - const int64_t ne0 = gf->nodes[i]->ne[0]; - const int64_t ne1 = gf->nodes[i]->ne[1]; - const int64_t ne2 = gf->nodes[i]->ne[2]; + const int64_t ne0 = gf->nodes[i]->ne[0]; + const int64_t ne1 = gf->nodes[i]->ne[1]; + const int64_t ne2 = gf->nodes[i]->ne[2]; const uint64_t nb0 = gf->nodes[i]->nb[0]; const uint64_t nb1 = gf->nodes[i]->nb[1]; @@ -515,6 +515,7 @@ int llama_mtl_eval( if (ggml_is_contiguous(gf->nodes[i]->src0) && ggml_is_contiguous(gf->nodes[i]->src1) && (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { + if (encoder != nil) { [encoder endEncoding]; encoder = nil; @@ -649,7 +650,7 @@ int llama_mtl_eval( const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; const float eps = 1e-6f; - const int nth = 32; + const int nth = 256; [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -779,22 +780,22 @@ int llama_mtl_eval( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break;