From ae36f009ef6f71162a47cbcd792428a0bf79841c Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 22 Oct 2023 14:25:25 -0300 Subject: [PATCH] Add new gpt_params_parse_ex function to hide arg-parse impl --- common/common.cpp | 16 +++++++++++++++- common/common.h | 2 ++ examples/embedding/embedding.cpp | 10 +--------- examples/infill/infill.cpp | 10 +--------- examples/llava/llava.cpp | 11 +---------- examples/main/main.cpp | 11 +---------- examples/parallel/parallel.cpp | 10 +--------- examples/perplexity/perplexity.cpp | 10 +--------- examples/save-load-state/save-load-state.cpp | 10 +--------- examples/speculative/speculative.cpp | 10 +--------- 10 files changed, 25 insertions(+), 75 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 378eb30f4..51b9b12cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -102,7 +102,21 @@ void process_escapes(std::string& input) { input.resize(output_idx); } -bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { +bool gpt_params_parse(int argc, char** argv, gpt_params& params) { + try { + if (!gpt_params_parse_ex(argc, argv, params)) { + gpt_print_usage(argc, argv, gpt_params()); + return false; + } + } + catch (const std::invalid_argument& ex) { + fprintf(stderr, ex.what()); + gpt_print_usage(argc, argv, gpt_params()); + return false; + } +} + +bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; const std::string arg_prefix = "--"; diff --git a/common/common.h b/common/common.h index 84523a4fb..e9fdfc212 100644 --- a/common/common.h +++ b/common/common.h @@ -110,6 +110,8 @@ struct gpt_params { std::string image = ""; // path to an image file }; +bool gpt_params_parse_ex(int argc, char** argv, gpt_params& params); + bool gpt_params_parse(int argc, char ** argv, gpt_params & params); void gpt_print_usage(int argc, char ** argv, const gpt_params & params); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 5dac3bd06..14075609e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -11,15 +11,7 @@ int main(int argc, char ** argv) { gpt_params params; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 74d41783e..6331335e3 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -107,15 +107,7 @@ int main(int argc, char ** argv) { llama_sampling_params & sparams = params.sparams; g_params = ¶ms; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index d298bb28e..f0974d5bc 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -17,16 +17,7 @@ int main(int argc, char ** argv) { gpt_params params; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - show_additional_info(argc, argv); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { show_additional_info(argc, argv); return 1; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 282a4898c..db5309afe 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -105,18 +105,9 @@ int main(int argc, char ** argv) { gpt_params params; g_params = ¶ms; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { return 1; } - llama_sampling_params & sparams = params.sparams; #ifndef LOG_DISABLE_LOGS diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 18343b2fa..eb64adef8 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -102,15 +102,7 @@ int main(int argc, char ** argv) { gpt_params params; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (gpt_params_parse(argc, argv, params) == false) { return 1; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index a4dbad95d..7d0038bd4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -682,15 +682,7 @@ int main(int argc, char ** argv) { gpt_params params; params.n_batch = 512; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 34896772d..38d05f4d3 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -11,15 +11,7 @@ int main(int argc, char ** argv) { params.prompt = "The quick brown fox"; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index d66f88ce1..894321ce9 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -24,15 +24,7 @@ struct seq_draft { int main(int argc, char ** argv) { gpt_params params; - try { - if (!gpt_params_parse(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - return 1; - } - } - catch (const std::invalid_argument& ex) { - fprintf(stderr, ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + if (gpt_params_parse(argc, argv, params) == false) { return 1; }