From a21715144416df1bf92c6134cb9181bc7bc81a85 Mon Sep 17 00:00:00 2001
From: Colin Calvert
Date: Fri, 18 Aug 2023 15:41:51 -0500
Subject: [PATCH] add gqa parameter (Llama 2 70b support)

---
 examples/llama-bench/llama-bench.cpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 266c8eab3..ecb573a6d 100755
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -132,6 +132,7 @@ enum output_formats {CSV, JSON, MARKDOWN, SQL};
 
 struct cmd_params {
     std::vector<std::string> model;
+    std::vector<int> n_gqa;
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
@@ -149,6 +150,7 @@
 
 static const cmd_params cmd_params_defaults = {
     /* model    */ {"models/7B/ggml-model-q4_0.bin"},
+    /* n_gqa    */ {1},
     /* n_prompt */ {512},
     /* n_gen    */ {128},
     /* n_batch  */ {512},
@@ -170,6 +172,7 @@ static void print_usage(int /* argc */, char ** argv) {
     fprintf(stdout, "options:\n");
     fprintf(stdout, "  -h, --help\n");
     fprintf(stdout, "  -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
+    fprintf(stdout, "  -gqa N, --gqa (default: %s)\n", join(cmd_params_defaults.n_gqa, ",").c_str());
     fprintf(stdout, "  -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     fprintf(stdout, "  -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     fprintf(stdout, "  -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
@@ -215,8 +218,15 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<std::string>(argv[i], split_delim);
             params.model.insert(params.model.end(), p.begin(), p.end());
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<int>(argv[i], split_delim);
+            params.n_gqa.insert(params.n_gqa.end(), p.begin(), p.end());
         } else if (arg == "-p" || arg == "--n-prompt") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -337,6 +347,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
 
     // set defaults
     if (params.model.empty())    { params.model = cmd_params_defaults.model; }
+    if (params.n_gqa.empty())    { params.n_gqa = cmd_params_defaults.n_gqa; }
    if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())    { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())  { params.n_batch = cmd_params_defaults.n_batch; }
@@ -353,6 +364,7 @@
 
 struct cmd_params_instance {
     std::string model;
+    int n_gqa;
     int n_prompt;
     int n_gen;
     int n_batch;
@@ -366,6 +378,7 @@ struct cmd_params_instance {
 
     llama_context_params to_llama_params() const {
         llama_context_params lparams = llama_context_default_params();
+        lparams.n_gqa = n_gqa;
         lparams.n_ctx = n_prompt + n_gen;
         lparams.n_batch = n_batch;
         lparams.f16_kv = !f32_kv;
@@ -383,6 +396,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     std::vector<cmd_params_instance> instances;
 
     for (const auto & m : params.model)
+    for (const auto & ng : params.n_gqa)
     for (const auto & nb : params.n_batch)
     for (const auto & fk : params.f32_kv)
     for (const auto & nl : params.n_gpu_layers)
@@ -393,6 +407,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model = */ m,
+            /* .n_gqa = */ ng,
             /* .n_prompt = */ n_prompt,
             /* .n_gen = */ n_gen,
             /* .n_batch = */ nb,
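
Usage note (not part of the patch): with this change applied, llama-bench can load
Llama 2 70B, whose grouped-query attention needs n_gqa = 8 (64 attention heads
shared across 8 KV heads). A sketch of an invocation, assuming a local 70B GGML
q4_0 model at an illustrative path:

    # -gqa 8 matches Llama 2 70B's grouped-query attention factor (model path is hypothetical)
    ./llama-bench -m models/70B/ggml-model-q4_0.bin -gqa 8

Like the other llama-bench options, -gqa accepts a comma-separated list and is swept
as one more axis of the benchmark grid; the default of 1 leaves non-GQA models
(7B/13B) unchanged.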