refactor gpt_params_parse

This commit is contained in:
Xuan Son Nguyen 2024-09-09 20:23:45 +02:00
parent cf2a874142
commit 96311e3248
28 changed files with 38 additions and 61 deletions

View file

@ -169,7 +169,7 @@ static void gpt_params_handle_model_default(gpt_params & params) {
// CLI argument parsing functions // CLI argument parsing functions
// //
static bool gpt_params_parse_ex(int argc, char ** argv, llama_arg_context & ctx_arg) { static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx_arg) {
std::string arg; std::string arg;
const std::string arg_prefix = "--"; const std::string arg_prefix = "--";
gpt_params & params = ctx_arg.params; gpt_params & params = ctx_arg.params;
@ -290,7 +290,7 @@ static bool gpt_params_parse_ex(int argc, char ** argv, llama_arg_context & ctx_
return true; return true;
} }
static void gpt_params_print_usage(llama_arg_context & ctx_arg) { static void gpt_params_print_usage(gpt_params_context & ctx_arg) {
auto print_options = [](std::vector<llama_arg *> & options) { auto print_options = [](std::vector<llama_arg *> & options) {
for (llama_arg * opt : options) { for (llama_arg * opt : options) {
printf("%s", opt->to_string().c_str()); printf("%s", opt->to_string().c_str());
@ -319,7 +319,8 @@ static void gpt_params_print_usage(llama_arg_context & ctx_arg) {
print_options(specific_options); print_options(specific_options);
} }
bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg) { bool gpt_params_parse(int argc, char ** argv, gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) {
auto ctx_arg = gpt_params_parser_init(params, ex, print_usage);
const gpt_params params_org = ctx_arg.params; // the example can modify the default params const gpt_params params_org = ctx_arg.params; // the example can modify the default params
try { try {
@ -343,8 +344,8 @@ bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg) {
return true; return true;
} }
llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) { gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) {
llama_arg_context ctx_arg(params); gpt_params_context ctx_arg(params);
ctx_arg.print_usage = print_usage; ctx_arg.print_usage = print_usage;
ctx_arg.ex = ex; ctx_arg.ex = ex;

View file

@ -61,17 +61,17 @@ struct llama_arg {
std::string to_string(); std::string to_string();
}; };
struct llama_arg_context { struct gpt_params_context {
enum llama_example ex = LLAMA_EXAMPLE_COMMON; enum llama_example ex = LLAMA_EXAMPLE_COMMON;
gpt_params & params; gpt_params & params;
std::vector<llama_arg> options; std::vector<llama_arg> options;
void(*print_usage)(int, char **) = nullptr; void(*print_usage)(int, char **) = nullptr;
llama_arg_context(gpt_params & params) : params(params) {} gpt_params_context(gpt_params & params) : params(params) {}
}; };
// optionally, we can provide "print_usage" to print example usage
llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
// parse input arguments from CLI // parse input arguments from CLI
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) // if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg); bool gpt_params_parse(int argc, char ** argv, gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
// function to be used by test-arg-parser
gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

View file

@ -38,8 +38,7 @@ static void print_usage(int, char ** argv) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_BENCH, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -19,8 +19,7 @@ int main(int argc, char ** argv) {
params.prompt = "Hello my name is"; params.prompt = "Hello my name is";
params.n_predict = 32; params.n_predict = 32;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -389,8 +389,7 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -80,8 +80,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_EMBEDDING); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -145,8 +145,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -402,8 +402,7 @@ static void print_usage(int, char ** argv) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -10,7 +10,7 @@ static void export_md(std::string fname, llama_example ex) {
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc); std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, ex); if (!gpt_params_parse(argc, argv, params, ex)) {
file << "| Argument | Explanation |\n"; file << "| Argument | Explanation |\n";
file << "| -------- | ----------- |\n"; file << "| -------- | ----------- |\n";

View file

@ -155,8 +155,7 @@ static std::string gritlm_instruction(const std::string & instruction) {
int main(int argc, char * argv[]) { int main(int argc, char * argv[]) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -578,8 +578,7 @@ int main(int argc, char ** argv) {
params.logits_all = true; params.logits_all = true;
params.verbosity = 1; params.verbosity = 1;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_IMATRIX, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -106,8 +106,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
g_params = &params; g_params = &params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_INFILL); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_INFILL)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -279,8 +279,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LLAVA, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -255,8 +255,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, show_additional_info); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, show_additional_info)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -38,8 +38,7 @@ struct ngram_container {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -14,8 +14,7 @@
int main(int argc, char ** argv){ int main(int argc, char ** argv){
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -16,8 +16,7 @@
int main(int argc, char ** argv){ int main(int argc, char ** argv){
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -14,8 +14,7 @@
int main(int argc, char ** argv){ int main(int argc, char ** argv){
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -139,8 +139,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector<l
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
g_params = &params; g_params = &params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_MAIN, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -102,8 +102,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_PARALLEL); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -20,8 +20,7 @@ int main(int argc, char ** argv) {
params.n_keep = 32; params.n_keep = 32;
params.i_pos = -1; params.i_pos = -1;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_PASSKEY, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -1968,8 +1968,7 @@ int main(int argc, char ** argv) {
params.n_ctx = 512; params.n_ctx = 512;
params.logits_all = true; params.logits_all = true;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_PERPLEXITY); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -112,8 +112,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_RETRIEVAL, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -11,8 +11,7 @@ int main(int argc, char ** argv) {
params.prompt = "The quick brown fox"; params.prompt = "The quick brown fox";
params.sparams.seed = 1234; params.sparams.seed = 1234;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -2425,8 +2425,7 @@ int main(int argc, char ** argv) {
// own arguments required by this example // own arguments required by this example
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_SERVER); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -19,8 +19,7 @@ int main(int argc, char ** argv) {
params.prompt = "Hello my name is"; params.prompt = "Hello my name is";
params.n_predict = 32; params.n_predict = 32;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -29,8 +29,7 @@ struct seq_draft {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_SPECULATIVE); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
if (!gpt_params_parse(argc, argv, ctx_arg)) {
return 1; return 1;
} }

View file

@ -53,7 +53,7 @@ int main(void) {
}; };
std::vector<std::string> argv; std::vector<std::string> argv;
auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
printf("test-arg-parser: test invalid usage\n\n"); printf("test-arg-parser: test invalid usage\n\n");