diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f442f2b56..8d2005cf6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -608,6 +608,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -724,9 +725,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_ctx = std::stoi(argv[i]); } + else if (arg == "-gqa" || arg == "--gqa") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.n_gqa = std::stoi(argv[i]); + } else if (arg == "--rope-freq-base") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } @@ -734,7 +745,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } else if (arg == "--rope-freq-scale") { - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; }