mtl : more threads for rms_norm + better timing
This commit is contained in:
parent
70c3387726
commit
b088e14a7e
2 changed files with 33 additions and 21 deletions
|
@ -46,10 +46,21 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
const std::vector<int> tmp(n_batch, 1); // BOS
|
const std::vector<int> tmp(n_batch, 1); // BOS
|
||||||
|
|
||||||
|
// warmup
|
||||||
|
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
|
||||||
|
|
||||||
|
const int n_iter = 16;
|
||||||
|
|
||||||
|
const int64_t t0 = ggml_time_us();
|
||||||
|
|
||||||
// the actual inference happens here
|
// the actual inference happens here
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < n_iter; ++i) {
|
||||||
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
|
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t t1 = ggml_time_us();
|
||||||
|
|
||||||
|
printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_mtl_free(ctx_mtl);
|
llama_mtl_free(ctx_mtl);
|
||||||
|
|
|
@ -492,9 +492,9 @@ int llama_mtl_eval(
|
||||||
const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
|
const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
|
||||||
const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
|
const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
|
||||||
|
|
||||||
const int64_t ne0 = gf->nodes[i]->ne[0];
|
const int64_t ne0 = gf->nodes[i]->ne[0];
|
||||||
const int64_t ne1 = gf->nodes[i]->ne[1];
|
const int64_t ne1 = gf->nodes[i]->ne[1];
|
||||||
const int64_t ne2 = gf->nodes[i]->ne[2];
|
const int64_t ne2 = gf->nodes[i]->ne[2];
|
||||||
|
|
||||||
const uint64_t nb0 = gf->nodes[i]->nb[0];
|
const uint64_t nb0 = gf->nodes[i]->nb[0];
|
||||||
const uint64_t nb1 = gf->nodes[i]->nb[1];
|
const uint64_t nb1 = gf->nodes[i]->nb[1];
|
||||||
|
@ -515,6 +515,7 @@ int llama_mtl_eval(
|
||||||
if (ggml_is_contiguous(gf->nodes[i]->src0) &&
|
if (ggml_is_contiguous(gf->nodes[i]->src0) &&
|
||||||
ggml_is_contiguous(gf->nodes[i]->src1) &&
|
ggml_is_contiguous(gf->nodes[i]->src1) &&
|
||||||
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
|
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
|
||||||
|
|
||||||
if (encoder != nil) {
|
if (encoder != nil) {
|
||||||
[encoder endEncoding];
|
[encoder endEncoding];
|
||||||
encoder = nil;
|
encoder = nil;
|
||||||
|
@ -649,7 +650,7 @@ int llama_mtl_eval(
|
||||||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
||||||
const float eps = 1e-6f;
|
const float eps = 1e-6f;
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 256;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
@ -779,22 +780,22 @@ int llama_mtl_eval(
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue