diff --git a/Makefile b/Makefile
index 332496cfc..9c61d3ec0 100644
--- a/Makefile
+++ b/Makefile
@@ -43,6 +43,7 @@ BUILD_TARGETS = \
 
 # Binaries only useful for tests
 TEST_TARGETS = \
+	tests/test-arg-parser \
 	tests/test-autorelease \
 	tests/test-backend-ops \
 	tests/test-chat-template \
@@ -1505,6 +1506,11 @@ run-benchmark-matmult: llama-benchmark-matmult
 
 .PHONY: run-benchmark-matmult swift
 
+tests/test-arg-parser: tests/test-arg-parser.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 tests/test-llama-grammar: tests/test-llama-grammar.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
diff --git a/common/common.cpp b/common/common.cpp
index 09e3a992c..ce9199c84 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -383,8 +383,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector<llama_arg> & options) {
     const std::string arg_prefix = "--";
     llama_sampling_params & sparams = params.sparams;
 
-    std::unordered_map<std::string, const llama_arg *> arg_to_options;
-    for (const auto & opt : options) {
+    std::unordered_map<std::string, llama_arg *> arg_to_options;
+    for (auto & opt : options) {
         for (const auto & arg : opt.args) {
             arg_to_options[arg] = &opt;
         }
@@ -404,8 +404,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector<llama_arg> & options) {
         if (arg_to_options.find(arg) == arg_to_options.end()) {
             throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str()));
         }
+        auto opt = *arg_to_options[arg];
         try {
-            auto opt = *arg_to_options[arg];
             if (opt.handler_void) {
                 opt.handler_void();
                 continue;
@@ -431,7 +431,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector<llama_arg> & options) {
                 continue;
             }
         } catch (std::exception & e) {
-            throw std::invalid_argument(format("error: %s", e.what()));
+            throw std::invalid_argument(format(
+                "error while handling argument \"%s\": %s\n\n"
+                "usage:\n%s\n\nto show complete usage, run with -h",
+                arg.c_str(), e.what(), arg_to_options[arg]->to_string(false).c_str()));
         }
     }
 
@@ -592,39 +595,49 @@ static std::vector<std::string> break_str_into_lines(std::string input, size_t max_char_per_line) {
     return result;
 }
 
-void gpt_params_print_usage(std::vector<llama_arg> & options) {
+std::string llama_arg::to_string(bool markdown) {
+    // params for printing to console
     const static int n_leading_spaces = 40;
     const static int n_char_per_line_help = 70; // TODO: detect this based on current console
-
-    auto print_options = [](std::vector<llama_arg *> & options) {
-        std::string leading_spaces(n_leading_spaces, ' ');
-        for (const auto & opt : options) {
-            std::ostringstream ss;
-            for (const auto & arg : opt->args) {
-                if (&arg == &opt->args.front()) {
-                    ss << (opt->args.size() == 1 ? arg : format("%-7s", (arg + ",").c_str()));
-                } else {
-                    ss << arg << (&arg != &opt->args.back() ? ", " : "");
-                }
-            }
-            if (!opt->value_hint.empty()) ss << " " << opt->value_hint;
-            if (ss.tellp() > n_leading_spaces - 3) {
-                // current line is too long, add new line
-                ss << "\n" << leading_spaces;
-            } else {
-                // padding between arg and help, same line
-                ss << std::string(leading_spaces.size() - ss.tellp(), ' ');
-            }
-            const auto help_lines = break_str_into_lines(opt->help, n_char_per_line_help);
-            for (const auto & line : help_lines) {
-                ss << (&line == &help_lines.front() ? "" : leading_spaces) << line << "\n";
-            }
-            printf("%s", ss.str().c_str());
+    std::string leading_spaces(n_leading_spaces, ' ');
+
+    std::ostringstream ss;
+    if (markdown) ss << "| `";
+    for (const auto & arg : args) {
+        if (arg == args.front()) {
+            ss << (args.size() == 1 ? arg : format("%-7s", (arg + ",").c_str()));
+        } else {
+            ss << arg << (arg != args.back() ? ", " : "");
+        }
+    }
+    if (!value_hint.empty()) ss << " " << value_hint;
+    if (!markdown) {
+        if (ss.tellp() > n_leading_spaces - 3) {
+            // current line is too long, add new line
+            ss << "\n" << leading_spaces;
+        } else {
+            // padding between arg and help, same line
+            ss << std::string(leading_spaces.size() - ss.tellp(), ' ');
+        }
+        const auto help_lines = break_str_into_lines(help, n_char_per_line_help);
+        for (const auto & line : help_lines) {
+            ss << (&line == &help_lines.front() ? "" : leading_spaces) << line << "\n";
+        }
+    } else {
+        ss << "` | " << help << " |";
+    }
+    return ss.str();
+}
+
+void gpt_params_print_usage(std::vector<llama_arg> & options) {
+    auto print_options = [](std::vector<llama_arg *> & options) {
+        for (llama_arg * opt : options) {
+            printf("%s", opt->to_string(false).c_str());
         }
     };
 
-    std::vector<const llama_arg *> common_options;
-    std::vector<const llama_arg *> specific_options;
+    std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> specific_options;
     for (auto & opt : options) {
         if (opt.in_example(LLAMA_EXAMPLE_COMMON)) {
             common_options.push_back(&opt);
@@ -1688,7 +1701,7 @@ std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example ex) {
         }
     ));
     add_opt(llama_arg(
-        {"-sm", "--split-mode"}, "SPLIT_MODE",
+        {"-sm", "--split-mode"}, "{none,layer,row}",
         "how to split the model across multiple GPUs, one of:\n"
         "- none: use one GPU only\n"
         "- layer (default): split layers and KV across GPUs\n"
diff --git a/common/common.h b/common/common.h
index 27e908d7f..05211bf97 100644
--- a/common/common.h
+++ b/common/common.h
@@ -331,6 +331,8 @@ struct llama_arg {
     bool in_example(enum llama_example ex) {
         return examples.find(ex) != examples.end();
     }
+
+    std::string to_string(bool markdown);
 };
 
 std::vector<llama_arg> gpt_params_parser_init(gpt_params & params, llama_example ex);
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0207e3a59..30e71cfd4 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -108,6 +108,7 @@ llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
 #llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
 
 # llama_target_and_test(test-double-float.cpp) # SLOW
+llama_target_and_test(test-arg-parser.cpp)
 llama_target_and_test(test-quantize-fns.cpp)
 llama_target_and_test(test-quantize-perf.cpp)
 llama_target_and_test(test-sampling.cpp)
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
new file mode 100644
index 000000000..8b95a59d3
--- /dev/null
+++ b/tests/test-arg-parser.cpp
@@ -0,0 +1,67 @@
+#include <string>
+#include <vector>
+#include <sstream>
+
+#undef NDEBUG
+#include <cassert>
+
+#include "common.h"
+
+int main(void) {
+    gpt_params params;
+
+    printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
+    for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
+        try {
+            gpt_params_parser_init(params, (enum llama_example)ex);
+        } catch (std::exception & e) {
+            printf("%s\n", e.what());
+            assert(false);
+        }
+    }
+
+    auto list_str_to_char = [](std::vector<std::string> & argv) -> std::vector<char *> {
+        std::vector<char *> res;
+        for (auto & arg : argv) {
+            res.push_back(const_cast<char *>(arg.data()));
+        }
+        return res;
+    };
+
+    std::vector<std::string> argv;
+    auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+
+    printf("test-arg-parser: test invalid usage\n\n");
+
+    argv = {"binary_name", "-m"};
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+
+    argv = {"binary_name", "-ngl", "hello"};
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+
+    argv = {"binary_name", "-sm", "hello"};
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+
+
+    printf("test-arg-parser: test valid usage\n\n");
+
+    argv = {"binary_name", "-m", "model_file.gguf"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.model == "model_file.gguf");
+
+    argv = {"binary_name", "-t", "1234"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.cpuparams.n_threads == 1234);
+
+    argv = {"binary_name", "--verbose"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.verbosity == 1);
+
+    argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(params.model == "abc.gguf");
+    assert(params.n_predict == 6789);
+    assert(params.n_batch == 9090);
+
+    printf("test-arg-parser: all tests OK\n\n");
+}