add gqa parameter (Llama 2 70B support)

Author: Colin Calvert
Date:   2023-08-18 15:41:51 -05:00
parent 2d8b76a110
commit a217151444


@@ -132,6 +132,7 @@ enum output_formats {CSV, JSON, MARKDOWN, SQL};
 struct cmd_params {
     std::vector<std::string> model;
+    std::vector<int> n_gqa;
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
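Every sweep parameter in cmd_params is a std::vector rather than a scalar, so one run can benchmark several values per parameter; the new n_gqa field follows the same pattern (see the expansion sketch after the last hunk below).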
@@ -149,6 +150,7 @@ struct cmd_params {
 static const cmd_params cmd_params_defaults = {
     /* model    */ {"models/7B/ggml-model-q4_0.bin"},
+    /* n_gqa    */ {1},
     /* n_prompt */ {512},
     /* n_gen    */ {128},
     /* n_batch  */ {512},
@@ -170,6 +172,7 @@ static void print_usage(int /* argc */, char ** argv) {
     fprintf(stdout, "options:\n");
     fprintf(stdout, " -h, --help\n");
     fprintf(stdout, " -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
+    fprintf(stdout, " -gqa N, --gqa <n> (default: %s)\n", join(cmd_params_defaults.n_gqa, ",").c_str());
     fprintf(stdout, " -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     fprintf(stdout, " -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     fprintf(stdout, " -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
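With this hunk the help text advertises the new flag. Like the other options it accepts a comma-separated list, so an invocation along the lines of `llama-bench -m models/70B/ggml-model-q4_0.bin -gqa 8` benchmarks a Llama 2 70B model, and `-gqa 1,8` sweeps both values in one run (the invocation here is illustrative, not taken from the commit).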
@@ -215,6 +218,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<std::string>(argv[i], split_delim);
             params.model.insert(params.model.end(), p.begin(), p.end());
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<int>(argv[i], split_delim);
+            params.n_gqa.insert(params.n_gqa.end(), p.begin(), p.end());
         } else if (arg == "-p" || arg == "--n-prompt") {
             if (++i >= argc) {
                 invalid_param = true;
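The parsing above relies on a split<T> helper that this hunk does not show. A minimal sketch of what such a helper could look like, assuming split_delim is a comma; the real helper in llama-bench may differ:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the split<T> helper used above: breaks a
// delimited string such as "1,8" into typed values, so "-gqa 1,8"
// yields {1, 8} and both settings get benchmarked.
template <typename T>
static std::vector<T> split(const std::string & str, char delim) {
    std::vector<T> values;
    std::istringstream stream(str);
    std::string token;
    while (std::getline(stream, token, delim)) {
        std::istringstream item(token);
        T value;
        item >> value;
        values.push_back(value);
    }
    return values;
}
```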
@@ -337,6 +347,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     // set defaults
     if (params.model.empty())    { params.model = cmd_params_defaults.model; }
+    if (params.n_gqa.empty())    { params.n_gqa = cmd_params_defaults.n_gqa; }
     if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())    { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())  { params.n_batch = cmd_params_defaults.n_batch; }
@@ -353,6 +364,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
 struct cmd_params_instance {
     std::string model;
+    int n_gqa;
     int n_prompt;
     int n_gen;
     int n_batch;
@@ -366,6 +378,7 @@ struct cmd_params_instance {
     llama_context_params to_llama_params() const {
         llama_context_params lparams = llama_context_default_params();
+        lparams.n_gqa = n_gqa;
         lparams.n_ctx = n_prompt + n_gen;
         lparams.n_batch = n_batch;
         lparams.f16_kv = !f32_kv;
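For reference, the GQA value being wired through here is the ratio of attention heads to key/value heads. A small illustrative computation, using the published Llama 2 70B head counts (these numbers come from the model card, not from this commit):

```cpp
// Illustrative only: Llama 2 70B uses grouped-query attention with
// 64 attention heads sharing 8 KV heads, hence -gqa 8.
int n_head    = 64;                  // attention heads in Llama 2 70B
int n_head_kv = 8;                   // key/value heads
int n_gqa     = n_head / n_head_kv;  // == 8
```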
@@ -383,6 +396,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     std::vector<cmd_params_instance> instances;
     for (const auto & m : params.model)
+    for (const auto & ng : params.n_gqa)
     for (const auto & nb : params.n_batch)
     for (const auto & fk : params.f32_kv)
     for (const auto & nl : params.n_gpu_layers)
@@ -393,6 +407,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model = */ m,
+            /* .n_gqa = */ ng,
             /* .n_prompt = */ n_prompt,
             /* .n_gen = */ n_gen,
             /* .n_batch = */ nb,
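The stacked range-for loops above expand every combination of the swept parameters into one cmd_params_instance. A self-contained sketch of the same expansion pattern, with hypothetical values in place of the tool's real parsing:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical sweep values; llama-bench fills these vectors from
    // the command line instead.
    std::vector<const char *> models     = {"7B.bin", "70B.bin"};
    std::vector<int>          gqa_values = {1, 8};

    // Stacked loops yield the Cartesian product:
    // 2 models x 2 gqa values = 4 benchmark instances.
    for (const auto & m : models)
    for (const auto & ng : gqa_values) {
        std::printf("instance: model=%s n_gqa=%d\n", m, ng);
    }
    return 0;
}
```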