mtl : more threads for rms_norm + better timing

This commit is contained in:
Georgi Gerganov 2023-06-02 19:26:58 +03:00
parent 70c3387726
commit b088e14a7e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 33 additions and 21 deletions

View file

@ -46,10 +46,21 @@ int main(int argc, char ** argv) {
const std::vector<int> tmp(n_batch, 1); // BOS const std::vector<int> tmp(n_batch, 1); // BOS
// warmup
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
const int n_iter = 16;
const int64_t t0 = ggml_time_us();
// the actual inference happens here // the actual inference happens here
for (int i = 0; i < 10; ++i) { for (int i = 0; i < n_iter; ++i) {
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past); llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
} }
const int64_t t1 = ggml_time_us();
printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
} }
llama_mtl_free(ctx_mtl); llama_mtl_free(ctx_mtl);

View file

@ -492,9 +492,9 @@ int llama_mtl_eval(
const uint64_t nb11 = gf->nodes[i]->src1->nb[1]; const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
const uint64_t nb12 = gf->nodes[i]->src1->nb[2]; const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
const int64_t ne0 = gf->nodes[i]->ne[0]; const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1]; const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2]; const int64_t ne2 = gf->nodes[i]->ne[2];
const uint64_t nb0 = gf->nodes[i]->nb[0]; const uint64_t nb0 = gf->nodes[i]->nb[0];
const uint64_t nb1 = gf->nodes[i]->nb[1]; const uint64_t nb1 = gf->nodes[i]->nb[1];
@ -515,6 +515,7 @@ int llama_mtl_eval(
if (ggml_is_contiguous(gf->nodes[i]->src0) && if (ggml_is_contiguous(gf->nodes[i]->src0) &&
ggml_is_contiguous(gf->nodes[i]->src1) && ggml_is_contiguous(gf->nodes[i]->src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) { if (encoder != nil) {
[encoder endEncoding]; [encoder endEncoding];
encoder = nil; encoder = nil;
@ -649,7 +650,7 @@ int llama_mtl_eval(
const uint64_t nb01 = gf->nodes[i]->src0->nb[1]; const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const float eps = 1e-6f; const float eps = 1e-6f;
const int nth = 32; const int nth = 256;
[encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -779,22 +780,22 @@ int llama_mtl_eval(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;