diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a885573ea..551557e8f 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -936,6 +936,7 @@ struct test { int n_prompt; int n_gen; std::string test_time; + std::string test_label; std::vector samples_ns; test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) { @@ -969,6 +970,26 @@ struct test { std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); test_time = buf; + // prepare test label for printing + switch (test_kind) { + case TEST_KIND_PP: + snprintf(buf, sizeof(buf), "pp%d", n_prompt); + break; + case TEST_KIND_TG: + snprintf(buf, sizeof(buf), "tg%d", n_gen); + break; + case TEST_KIND_PG: + snprintf(buf, sizeof(buf), "pp%d+tg%d", n_prompt, n_gen); + break; + case TEST_KIND_GP: + snprintf(buf, sizeof(buf), "tg%d@pp%d", n_gen, n_prompt); + break; + default: + snprintf(buf, sizeof(buf), "unknown"); + break; + } + test_label = buf; + (void) ctx; } @@ -1007,7 +1028,7 @@ struct test { "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", - "avg_ts", "stddev_ts", + "avg_ts", "stddev_ts", "test", }; return fields; } @@ -1078,7 +1099,8 @@ struct test { std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), - std::to_string(stdev_ts()) }; + std::to_string(stdev_ts()), + test_label }; return values; } @@ -1390,24 +1412,7 @@ struct markdown_printer : public printer { } else if (field == "backend") { value = test::get_backend(); } else if (field == "test") { - switch (t.test_kind) { - case TEST_KIND_PP: - snprintf(buf, sizeof(buf), "pp%d", t.n_prompt); - break; - case TEST_KIND_TG: - snprintf(buf, sizeof(buf), "tg%d", t.n_gen); - break; - case TEST_KIND_PG: - snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen); - break; - case TEST_KIND_GP: - snprintf(buf, sizeof(buf), "tg%d@pp%d", t.n_gen, t.n_prompt); - break; - default: - assert(false); - exit(1); - } - value = buf; + value = t.test_label; } else if (field == "t/s") { snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); value = buf;