add gqa parameter (Llama 2 70b support)
parent 2d8b76a110
commit a217151444
1 changed file with 15 additions and 0 deletions
@@ -132,6 +132,7 @@ enum output_formats {CSV, JSON, MARKDOWN, SQL};
 
 struct cmd_params {
     std::vector<std::string> model;
+    std::vector<int> n_gqa;
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
@@ -149,6 +150,7 @@ struct cmd_params {
 
 static const cmd_params cmd_params_defaults = {
     /* model */ {"models/7B/ggml-model-q4_0.bin"},
+    /* n_gqa */ {1},
     /* n_prompt */ {512},
     /* n_gen */ {128},
     /* n_batch */ {512},
@@ -170,6 +172,7 @@ static void print_usage(int /* argc */, char ** argv) {
     fprintf(stdout, "options:\n");
     fprintf(stdout, " -h, --help\n");
     fprintf(stdout, " -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
+    fprintf(stdout, " -gqa N, --gqa <n> (default: %s)\n", join(cmd_params_defaults.n_gqa, ",").c_str());
     fprintf(stdout, " -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     fprintf(stdout, " -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     fprintf(stdout, " -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
@@ -215,6 +218,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<std::string>(argv[i], split_delim);
             params.model.insert(params.model.end(), p.begin(), p.end());
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<int>(argv[i], split_delim);
+            params.n_gqa.insert(params.n_gqa.end(), p.begin(), p.end());
         } else if (arg == "-p" || arg == "--n-prompt") {
             if (++i >= argc) {
                 invalid_param = true;
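The new -gqa branch above reuses the same comma-separated parsing as -m, so one invocation can sweep several GQA values (e.g. "-gqa 1,8"). As a rough illustration only: the split<T> helper and split_delim already exist in this file, and the standalone splitter below is a hypothetical stand-in for how such a helper typically works, not the committed implementation.

// Hypothetical sketch of a comma-separated splitter in the spirit of the
// split<T>(argv[i], split_delim) calls above; the real helper may differ.
#include <sstream>
#include <string>
#include <vector>

template <typename T>
static std::vector<T> split_values(const std::string & str, char delim) {
    std::vector<T> values;
    std::istringstream stream(str);
    std::string token;
    while (std::getline(stream, token, delim)) {
        std::istringstream conv(token);
        T value;
        conv >> value;           // e.g. "8" -> 8 for T = int
        values.push_back(value);
    }
    return values;
}

// split_values<int>("1,8", ',') yields {1, 8}, so "-gqa 1,8" benchmarks both settings.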
@@ -337,6 +347,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
 
     // set defaults
     if (params.model.empty()) { params.model = cmd_params_defaults.model; }
+    if (params.n_gqa.empty()) { params.n_gqa = cmd_params_defaults.n_gqa; }
     if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
@@ -353,6 +364,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
 
 struct cmd_params_instance {
     std::string model;
+    int n_gqa;
     int n_prompt;
     int n_gen;
     int n_batch;
@@ -366,6 +378,7 @@ struct cmd_params_instance {
 
     llama_context_params to_llama_params() const {
         llama_context_params lparams = llama_context_default_params();
+        lparams.n_gqa = n_gqa;
         lparams.n_ctx = n_prompt + n_gen;
         lparams.n_batch = n_batch;
         lparams.f16_kv = !f32_kv;
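For context on the field forwarded above: n_gqa is the grouped-query attention factor, i.e. how many query heads share one key/value head. Plain multi-head models use 1 (the default in this patch), while Llama 2 70B needs 8, which is why this commit exists. A minimal, self-contained illustration follows; the head counts are the published Llama 2 70B configuration, hard-coded here for demonstration rather than read from any model file.

// Illustrative only: derive the GQA factor from head counts.
#include <cstdio>

int main() {
    const int n_head    = 64;                 // query heads in Llama 2 70B
    const int n_head_kv = 8;                  // key/value heads in Llama 2 70B
    const int n_gqa     = n_head / n_head_kv; // grouped-query attention factor
    std::printf("pass -gqa %d when benchmarking a Llama 2 70B model\n", n_gqa); // prints 8
    return 0;
}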
@@ -383,6 +396,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     std::vector<cmd_params_instance> instances;
 
     for (const auto & m : params.model)
+    for (const auto & ng : params.n_gqa)
     for (const auto & nb : params.n_batch)
     for (const auto & fk : params.f32_kv)
     for (const auto & nl : params.n_gpu_layers)
@@ -393,6 +407,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model = */ m,
+            /* .n_gqa = */ ng,
             /* .n_prompt = */ n_prompt,
             /* .n_gen = */ n_gen,
             /* .n_batch = */ nb,
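Taken together, the last two hunks extend the cartesian product that drives the benchmark: every value passed to -gqa is combined with every model, batch size, KV type, GPU-layer count and thread count, so "-gqa 1,8" doubles the number of benchmark instances. Below is a standalone sketch of that expansion with a hypothetical struct and made-up value lists, not the real cmd_params_instance.

// Rough sketch of the combination logic: each parameter list is one more nested
// loop, so the instance count is the product of the list sizes.
#include <cstdio>
#include <vector>

struct bench_instance { int n_gqa; int n_batch; int n_threads; };

int main() {
    const std::vector<int> n_gqa     = {1, 8};   // e.g. from "-gqa 1,8"
    const std::vector<int> n_batch   = {512};
    const std::vector<int> n_threads = {4, 8};

    std::vector<bench_instance> instances;
    for (int ng : n_gqa)
    for (int nb : n_batch)
    for (int nt : n_threads) {
        instances.push_back({ng, nb, nt});
    }
    std::printf("%zu benchmark instances\n", instances.size()); // 2 * 1 * 2 = 4
    return 0;
}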