diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 2e845a0ff..883e0e893 100755
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -90,8 +90,8 @@ struct cmd_params {

 static cmd_params cmd_params_defaults = {
     /* model         */ {"models/7B/ggml-model-q4_0.bin"},
-    /* n_prompt      */ {0, 512},
-    /* n_gen         */ {0, 128},
+    /* n_prompt      */ {512},
+    /* n_gen         */ {128},
     /* n_batch       */ {512},
     /* f32_kv        */ {false},
     /* n_threads     */ {get_num_physical_cores()},
@@ -319,7 +319,7 @@ struct cmd_params_instance {
     }
 };

-static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
+static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_params & params, int n_gen, int n_prompt) {
     std::vector<cmd_params_instance> instances;

     for (const auto & m : params.model)
@@ -330,15 +330,12 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & lv : params.low_vram)
     for (const auto & ts : params.tensor_split)
-    for (const auto & nt : params.n_threads)
-    for (const auto & ng : params.n_gen)
-    for (const auto & np : params.n_prompt) {
-        if (np == 0 && ng == 0) continue; // no prompt and no generation
-
+    for (const auto & nt : params.n_threads) {
         cmd_params_instance instance;
+
         instance.model = m;
-        instance.n_prompt = np;
-        instance.n_gen = ng;
+        instance.n_prompt = n_prompt;
+        instance.n_gen = n_gen;
         instance.n_batch = nb;
         instance.f32_kv = fk;
         instance.n_gpu_layers = nl;
@@ -350,6 +347,27 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param

         instances.push_back(instance);
     }
+    return instances;
+}
+
+static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
+    std::vector<cmd_params_instance> instances;
+
+    for (const auto & n_prompt : params.n_prompt) {
+        if (n_prompt == 0) {
+            continue;
+        }
+        auto instances_prompt = get_cmd_params_instances_int(params, 0, n_prompt);
+        instances.insert(instances.end(), instances_prompt.begin(), instances_prompt.end());
+    }
+
+    for (const auto & n_gen : params.n_gen) {
+        if (n_gen == 0) {
+            continue;
+        }
+        auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0);
+        instances.insert(instances.end(), instances_gen.begin(), instances_gen.end());
+    }

     return instances;
 }
@@ -633,12 +651,7 @@ struct markdown_printer : public printer {
         if (params.low_vram.size() > 1) {
             fields.push_back("low_vram");
         }
-        if (params.n_prompt.size() > 1 || (params.n_prompt.size() == 1 && params.n_prompt.at(0) != 0)) {
-            fields.push_back("n_prompt");
-        }
-        if (params.n_gen.size() > 1 || (params.n_gen.size() == 1 && params.n_gen.at(0) != 0)) {
-            fields.push_back("n_gen");
-        }
+        fields.push_back("test");
         fields.push_back("t/s");

         fprintf(fout, "|");
@@ -673,6 +686,17 @@ struct markdown_printer : public printer {
                 value = t.mparams.type;
             } else if (field == "backend") {
                 value = backend_params::get_backend();
+            } else if (field == "test") {
+                char buf[128];
+                if (t.bparams.n_prompt > 0) {
+                    snprintf(buf, sizeof(buf), "pp %d", t.bparams.n_prompt);
+                } else if (t.bparams.n_gen > 0) {
+                    snprintf(buf, sizeof(buf), "tg %d", t.bparams.n_gen);
+                } else {
+                    assert(false);
+                    exit(1);
+                }
+                value = buf;
             } else if (field == "t/s") {
                 char buf[128];
                 snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.tsamples.avg_ts(n_tokens), t.tsamples.stddev_ts(n_tokens));
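
The sketch below is a minimal, self-contained illustration of the behaviour this diff introduces: every non-zero `n_prompt` value becomes a prompt-processing ("pp") test with `n_gen = 0`, every non-zero `n_gen` value becomes a text-generation ("tg") test with `n_prompt = 0`, and each test is labelled for a single `test` column instead of separate `n_prompt`/`n_gen` columns. The names `bench_instance`, `make_instances`, and `test_label` are placeholders for this sketch, not the identifiers used in llama-bench.cpp.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Placeholder stand-in for cmd_params_instance, trimmed to the two fields that matter here.
struct bench_instance { int n_prompt; int n_gen; };

// Same idea as the new get_cmd_params_instances(): non-zero n_prompt values become
// prompt-processing tests, non-zero n_gen values become text-generation tests,
// and zeros are skipped entirely.
static std::vector<bench_instance> make_instances(const std::vector<int> & n_prompt,
                                                  const std::vector<int> & n_gen) {
    std::vector<bench_instance> instances;
    for (int np : n_prompt) {
        if (np > 0) { instances.push_back({np, 0}); }
    }
    for (int ng : n_gen) {
        if (ng > 0) { instances.push_back({0, ng}); }
    }
    return instances;
}

// Same formatting rule as the new "test" column: "pp <n_prompt>" or "tg <n_gen>".
static std::string test_label(const bench_instance & inst) {
    char buf[32];
    if (inst.n_prompt > 0) {
        snprintf(buf, sizeof(buf), "pp %d", inst.n_prompt);
    } else {
        snprintf(buf, sizeof(buf), "tg %d", inst.n_gen);
    }
    return buf;
}

int main() {
    // With the new defaults {512} and {128} this prints "| pp 512 |" and "| tg 128 |".
    for (const auto & inst : make_instances({512}, {128})) {
        printf("| %s |\n", test_label(inst).c_str());
    }
    return 0;
}
```

Keeping prompt-only and generation-only runs as separate instances is what allows the markdown printer to replace the two conditional `n_prompt`/`n_gen` columns with the single, always-present `test` column.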