mtl : more threads for rms_norm + better timing
This commit is contained in:
parent
70c3387726
commit
b088e14a7e
2 changed files with 33 additions and 21 deletions
|
@@ -46,10 +46,21 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const std::vector<int> tmp(n_batch, 1); // BOS
|
||||
|
||||
// warmup
|
||||
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
|
||||
|
||||
const int n_iter = 16;
|
||||
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
// the actual inference happens here
|
||||
- for (int i = 0; i < 10; ++i) {
|
||||
+ for (int i = 0; i < n_iter; ++i) {
|
||||
llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
|
||||
}
|
||||
|
||||
const int64_t t1 = ggml_time_us();
|
||||
|
||||
printf("time: %.2f ms, %.2f ms/tok\n", (t1 - t0) / 1000.0, (t1 - t0) / 1000.0 / n_iter);
|
||||
}
|
||||
|
||||
llama_mtl_free(ctx_mtl);
|
||||
|
|
|
@@ -515,6 +515,7 @@ int llama_mtl_eval(
|
|||
if (ggml_is_contiguous(gf->nodes[i]->src0) &&
|
||||
ggml_is_contiguous(gf->nodes[i]->src1) &&
|
||||
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
|
||||
|
||||
if (encoder != nil) {
|
||||
[encoder endEncoding];
|
||||
encoder = nil;
|
||||
|
@@ -649,7 +650,7 @@ int llama_mtl_eval(
|
|||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
||||
const float eps = 1e-6f;
|
||||
|
||||
- const int nth = 32;
|
||||
+ const int nth = 256;
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue