From 96311e3248c9730b128b29999a4c8b7da0937ce9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 9 Sep 2024 20:23:45 +0200 Subject: [PATCH] refactor gpt_params_parse --- common/arg.cpp | 11 ++++++----- common/arg.h | 12 ++++++------ examples/batched-bench/batched-bench.cpp | 3 +-- examples/batched/batched.cpp | 3 +-- examples/cvector-generator/cvector-generator.cpp | 3 +-- examples/embedding/embedding.cpp | 3 +-- examples/eval-callback/eval-callback.cpp | 3 +-- examples/export-lora/export-lora.cpp | 3 +-- examples/gen-docs/gen-docs.cpp | 2 +- examples/gritlm/gritlm.cpp | 3 +-- examples/imatrix/imatrix.cpp | 3 +-- examples/infill/infill.cpp | 3 +-- examples/llava/llava-cli.cpp | 3 +-- examples/llava/minicpmv-cli.cpp | 3 +-- examples/lookahead/lookahead.cpp | 3 +-- examples/lookup/lookup-create.cpp | 3 +-- examples/lookup/lookup-stats.cpp | 3 +-- examples/lookup/lookup.cpp | 3 +-- examples/main/main.cpp | 3 +-- examples/parallel/parallel.cpp | 3 +-- examples/passkey/passkey.cpp | 3 +-- examples/perplexity/perplexity.cpp | 3 +-- examples/retrieval/retrieval.cpp | 3 +-- examples/save-load-state/save-load-state.cpp | 3 +-- examples/server/server.cpp | 3 +-- examples/simple/simple.cpp | 3 +-- examples/speculative/speculative.cpp | 3 +-- tests/test-arg-parser.cpp | 2 +- 28 files changed, 38 insertions(+), 61 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index cb207eed4..c5134be51 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -169,7 +169,7 @@ static void gpt_params_handle_model_default(gpt_params & params) { // CLI argument parsing functions // -static bool gpt_params_parse_ex(int argc, char ** argv, llama_arg_context & ctx_arg) { +static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx_arg) { std::string arg; const std::string arg_prefix = "--"; gpt_params & params = ctx_arg.params; @@ -290,7 +290,7 @@ static bool gpt_params_parse_ex(int argc, char ** argv, llama_arg_context & ctx_ return true; } -static void gpt_params_print_usage(llama_arg_context & ctx_arg) { +static void gpt_params_print_usage(gpt_params_context & ctx_arg) { auto print_options = [](std::vector & options) { for (llama_arg * opt : options) { printf("%s", opt->to_string().c_str()); @@ -319,7 +319,8 @@ static void gpt_params_print_usage(llama_arg_context & ctx_arg) { print_options(specific_options); } -bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg) { +bool gpt_params_parse(int argc, char ** argv, gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) { + auto ctx_arg = gpt_params_parser_init(params, ex, print_usage); const gpt_params params_org = ctx_arg.params; // the example can modify the default params try { @@ -343,8 +344,8 @@ bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg) { return true; } -llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) { - llama_arg_context ctx_arg(params); +gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **)) { + gpt_params_context ctx_arg(params); ctx_arg.print_usage = print_usage; ctx_arg.ex = ex; diff --git a/common/arg.h b/common/arg.h index c6b5f0cde..413de2c88 100644 --- a/common/arg.h +++ b/common/arg.h @@ -61,17 +61,17 @@ struct llama_arg { std::string to_string(); }; -struct llama_arg_context { +struct gpt_params_context { enum llama_example ex = LLAMA_EXAMPLE_COMMON; gpt_params & params; std::vector options; void(*print_usage)(int, char **) = nullptr; - llama_arg_context(gpt_params & params) : params(params) {} + gpt_params_context(gpt_params & params) : params(params) {} }; -// optionally, we can provide "print_usage" to print example usage -llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); - // parse input arguments from CLI // if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) -bool gpt_params_parse(int argc, char ** argv, llama_arg_context & ctx_arg); +bool gpt_params_parse(int argc, char ** argv, gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); + +// function to be used by test-arg-parser +gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index b6a64e152..a91e7f4bd 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -38,8 +38,7 @@ static void print_usage(int, char ** argv) { int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_BENCH, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) { return 1; } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 32be23dfc..a593a5431 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -19,8 +19,7 @@ int main(int argc, char ** argv) { params.prompt = "Hello my name is"; params.n_predict = 32; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { return 1; } diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 4707e28a8..569b6c38f 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -389,8 +389,7 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) { return 1; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 20312a3fe..da7c79253 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -80,8 +80,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_EMBEDDING); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) { return 1; } diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 9c8a5938b..bc7203143 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -145,8 +145,7 @@ int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 7c8c813ed..ff324926a 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -402,8 +402,7 @@ static void print_usage(int, char ** argv) { int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) { return 1; } diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp index b6d4725fd..640ee7184 100644 --- a/examples/gen-docs/gen-docs.cpp +++ b/examples/gen-docs/gen-docs.cpp @@ -10,7 +10,7 @@ static void export_md(std::string fname, llama_example ex) { std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc); gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, ex); + if (!gpt_params_parse(argc, argv, params, ex)) { file << "| Argument | Explanation |\n"; file << "| -------- | ----------- |\n"; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index d3bbf1e68..a2e0f895e 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -155,8 +155,7 @@ static std::string gritlm_instruction(const std::string & instruction) { int main(int argc, char * argv[]) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index a3771b750..032a90136 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -578,8 +578,7 @@ int main(int argc, char ** argv) { params.logits_all = true; params.verbosity = 1; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_IMATRIX, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) { return 1; } diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index f5d1c239d..9a527e244 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -106,8 +106,7 @@ int main(int argc, char ** argv) { gpt_params params; g_params = ¶ms; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_INFILL); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_INFILL)) { return 1; } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index e8974a355..e9108a9bd 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -279,8 +279,7 @@ int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LLAVA, print_usage); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) { return 1; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index b0a4b810f..3475bbce5 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -255,8 +255,7 @@ int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, show_additional_info); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, show_additional_info)) { return 1; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 93254cc6d..de8b792f2 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -38,8 +38,7 @@ struct ngram_container { int main(int argc, char ** argv) { gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index f632cd3b8..33287c02c 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -14,8 +14,7 @@ int main(int argc, char ** argv){ gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 2bda3a07e..f299d68a9 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -16,8 +16,7 @@ int main(int argc, char ** argv){ gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index c1f237cef..fff44a499 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -14,8 +14,7 @@ int main(int argc, char ** argv){ gpt_params params; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP); - if (!gpt_params_parse(argc, argv, ctx_arg)) { + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 38db1ba56..b986a865a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -139,8 +139,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vector argv; - auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); + if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { printf("test-arg-parser: test invalid usage\n\n");