mtl : more threads for rms_norm + better timing

This commit is contained in:
Georgi Gerganov 2023-06-02 19:26:58 +03:00
parent 70c3387726
commit b088e14a7e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 33 additions and 21 deletions

View file

@@ -46,10 +46,21 @@ int main(int argc, char ** argv) {
const std::vector<int> tmp(n_batch, 1); // BOS
// warmup
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
const int n_iter = 16;
const int64_t t0 = ggml_time_us();
// the actual inference happens here
for (int i = 0; i < 10; ++i) {
for (int i = 0; i < n_iter; ++i) {
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
}
const int64_t t1 = ggml_time_us();
printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
}
llama_mtl_free(ctx_mtl);

View file

@@ -515,6 +515,7 @@ int llama_mtl_eval(
if (ggml_is_contiguous(gf->nodes[i]->src0) &&
ggml_is_contiguous(gf->nodes[i]->src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
@@ -649,7 +650,7 @@ int llama_mtl_eval(
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const float eps = 1e-6f;
const int nth = 32;
const int nth = 256;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];