diff --git a/common/arg.cpp b/common/arg.cpp
index dee72bb53..93fdc6c43 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -22,6 +22,11 @@ llama_arg & llama_arg::set_env(const char * env) {
     return *this;
 }
 
+llama_arg & llama_arg::set_sparam() {
+    is_sparam = true;
+    return *this;
+}
+
 bool llama_arg::in_example(enum llama_example ex) {
     return examples.find(ex) != examples.end();
 }
@@ -289,19 +294,24 @@ static void gpt_params_print_usage(llama_arg_context & ctx_arg) {
     };
 
     std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> sparam_options;
     std::vector<llama_arg *> specific_options;
     for (auto & opt : ctx_arg.options) {
         // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
-        if (opt.in_example(ctx_arg.ex)) {
+        if (opt.is_sparam) {
+            sparam_options.push_back(&opt);
+        } else if (opt.in_example(ctx_arg.ex)) {
             specific_options.push_back(&opt);
         } else {
             common_options.push_back(&opt);
         }
     }
 
-    printf("----- common options -----\n\n");
+    printf("----- common params -----\n\n");
     print_options(common_options);
+    printf("\n\n----- sampling params -----\n\n");
+    print_options(sparam_options);
     // TODO: maybe convert enum llama_example to string
-    printf("\n\n----- example-specific options -----\n\n");
+    printf("\n\n----- example-specific params -----\n\n");
     print_options(specific_options);
 }
@@ -412,13 +422,6 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.use_color = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
-    add_opt(llama_arg(
-        {"-s", "--seed"}, "SEED",
-        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
-        [](gpt_params & params, const std::string & value) {
-            params.sparams.seed = std::stoul(value);
-        }
-    ));
     add_opt(llama_arg(
         {"-t", "--threads"}, "N",
         format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
@@ -641,7 +644,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_draft = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-ps", "--p-split"}, "N",
         format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -655,14 +658,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & value) {
             params.lookup_cache_static = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-lcd", "--lookup-cache-dynamic"}, "FNAME",
         "path to dynamic lookup cache to use for lookup decoding (updated by generation)",
         [](gpt_params & params, const std::string & value) {
             params.lookup_cache_dynamic = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-c", "--ctx-size"}, "N",
         format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -704,7 +707,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_chunks = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(llama_arg(
         {"-fa", "--flash-attn"},
         format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -747,7 +750,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             }
             params.in_files.push_back(value);
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(llama_arg(
         {"-bf", "--binary-file"}, "FNAME",
         "binary file containing the prompt (default: none)",
@@ -902,28 +905,35 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             const auto sampler_names = string_split(value, ';');
             params.sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
         }
-    ));
+    ).set_sparam());
+    add_opt(llama_arg(
+        {"-s", "--seed"}, "SEED",
+        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.seed = std::stoul(value);
+        }
+    ).set_sparam());
     add_opt(llama_arg(
         {"--sampling-seq"}, "SEQUENCE",
         format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](gpt_params & params, const std::string & value) {
             params.sparams.samplers = gpt_sampler_types_from_chars(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--ignore-eos"},
         "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)",
         [](gpt_params & params) {
             params.sparams.ignore_eos = true;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--penalize-nl"},
         format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"),
         [](gpt_params & params) {
             params.sparams.penalize_nl = true;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--temp"}, "N",
         format("temperature (default: %.1f)", (double)params.sparams.temp),
@@ -931,42 +941,42 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.temp = std::stof(value);
             params.sparams.temp = std::max(params.sparams.temp, 0.0f);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--top-k"}, "N",
         format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
         [](gpt_params & params, int value) {
             params.sparams.top_k = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--top-p"}, "N",
         format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.top_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--min-p"}, "N",
         format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.min_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--tfs"}, "N",
         format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
         [](gpt_params & params, const std::string & value) {
             params.sparams.tfs_z = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.typ_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--repeat-last-n"}, "N",
         format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n),
@@ -974,42 +984,42 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.penalty_last_n = value;
             params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--repeat-penalty"}, "N",
         format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_repeat = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--presence-penalty"}, "N",
         format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_present = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--frequency-penalty"}, "N",
         format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_freq = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--dynatemp-range"}, "N",
         format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
         [](gpt_params & params, const std::string & value) {
             params.sparams.dynatemp_range = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--dynatemp-exp"}, "N",
         format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent),
         [](gpt_params & params, const std::string & value) {
             params.sparams.dynatemp_exponent = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat"}, "N",
         format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
@@ -1017,21 +1027,21 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.sparams.mirostat = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat-lr"}, "N",
         format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta),
         [](gpt_params & params, const std::string & value) {
             params.sparams.mirostat_eta = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat-ent"}, "N",
         format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau),
         [](gpt_params & params, const std::string & value) {
             params.sparams.mirostat_tau = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"-l", "--logit-bias"}, "TOKEN_ID(+/-)BIAS",
         "modifies the likelihood of token appearing in the completion,\n"
@@ -1053,14 +1063,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
                 throw std::invalid_argument("invalid input format");
             }
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--grammar"}, "GRAMMAR",
         format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()),
         [](gpt_params & params, const std::string & value) {
             params.sparams.grammar = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--grammar-file"}, "FNAME",
         "file to read grammar from",
@@ -1075,14 +1085,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
                 std::back_inserter(params.sparams.grammar)
             );
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"-j", "--json-schema"}, "SCHEMA",
         "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
         [](gpt_params & params, const std::string & value) {
             params.sparams.grammar = json_schema_to_grammar(json::parse(value));
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--pooling"}, "{none,mean,cls,last}",
         "pooling type for embeddings, use model default if unspecified",
@@ -1310,21 +1320,21 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_sequences = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_PARALLEL}));
     add_opt(llama_arg(
         {"-cb", "--cont-batching"},
         format("enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled"),
         [](gpt_params & params) {
             params.cont_batching = true;
         }
-    ).set_env("LLAMA_ARG_CONT_BATCHING"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CONT_BATCHING"));
     add_opt(llama_arg(
         {"-nocb", "--no-cont-batching"},
         "disable continuous batching",
         [](gpt_params & params) {
             params.cont_batching = false;
         }
-    ).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
     add_opt(llama_arg(
         {"--mmproj"}, "FILE",
         "path to a multimodal projector file for LLaVA. see examples/llava/README.md",
@@ -1487,6 +1497,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & value) {
             params.lora_adapters.push_back({ std::string(value), 1.0 });
         }
+        // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(llama_arg(
         {"--lora-scaled"}, "FNAME", "SCALE",
@@ -1494,6 +1505,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & fname, const std::string & scale) {
             params.lora_adapters.push_back({ fname, std::stof(scale) });
         }
+        // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(llama_arg(
         {"--control-vector"}, "FNAME",
diff --git a/common/arg.h b/common/arg.h
index 728c326d6..eb9530dcb 100644
--- a/common/arg.h
+++ b/common/arg.h
@@ -27,6 +27,8 @@ enum llama_example {
     LLAMA_EXAMPLE_CVECTOR_GENERATOR,
     LLAMA_EXAMPLE_EXPORT_LORA,
     LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_LOOKUP,
+    LLAMA_EXAMPLE_PARALLEL,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -38,6 +40,7 @@ struct llama_arg {
     const char * value_hint_2 = nullptr; // for second arg value
     const char * env          = nullptr;
     std::string help;
+    bool is_sparam = false; // is current arg a sampling param?
     void (*handler_void)   (gpt_params & params) = nullptr;
     void (*handler_string) (gpt_params & params, const std::string &) = nullptr;
     void (*handler_str_str)(gpt_params & params, const std::string &, const std::string &) = nullptr;
@@ -74,6 +77,7 @@ struct llama_arg {
 
     llama_arg & set_examples(std::initializer_list<enum llama_example> examples);
     llama_arg & set_env(const char * env);
+    llama_arg & set_sparam();
     bool in_example(enum llama_example ex);
     bool get_value_from_env(std::string & output);
     bool has_value_from_env();
diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp
index f31108caf..611a216d7 100644
--- a/examples/lookup/lookup-create.cpp
+++ b/examples/lookup/lookup-create.cpp
@@ -13,7 +13,7 @@ int main(int argc, char ** argv){
 
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }
diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp
index 3d399c657..099a29d24 100644
--- a/examples/lookup/lookup-stats.cpp
+++ b/examples/lookup/lookup-stats.cpp
@@ -15,7 +15,7 @@ int main(int argc, char ** argv){
 
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_SPECULATIVE);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 1d8bac436..fb07051d2 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -12,7 +12,7 @@ int main(int argc, char ** argv){
 
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 0dc45a453..ecda353f1 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -100,7 +100,7 @@ int main(int argc, char ** argv) {
 
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_PARALLEL);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index c3a338d34..75b3451a9 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -17,7 +17,7 @@ int main(void) {
         auto ctx_arg = gpt_params_parser_init(params, (enum llama_example)ex);
         std::unordered_set<std::string> seen_args;
         std::unordered_set<std::string> seen_env_vars;
-        for (const auto & opt : options) {
+        for (const auto & opt : ctx_arg.options) {
             // check for args duplications
             for (const auto & arg : opt.args) {
                 if (seen_args.find(arg) == seen_args.end()) {
@@ -57,31 +57,31 @@ int main(void) {
     printf("test-arg-parser: test invalid usage\n\n");
 
     argv = {"binary_name", "-m"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     argv = {"binary_name", "-ngl", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     argv = {"binary_name", "-sm", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     printf("test-arg-parser: test valid usage\n\n");
 
     argv = {"binary_name", "-m", "model_file.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "model_file.gguf");
 
     argv = {"binary_name", "-t", "1234"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.cpuparams.n_threads == 1234);
 
     argv = {"binary_name", "--verbose"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.verbosity == 1);
 
     argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "abc.gguf");
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
@@ -94,12 +94,12 @@ int main(void) {
 
     setenv("LLAMA_ARG_THREADS", "blah", true);
     argv = {"binary_name"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "blah.gguf");
     assert(params.cpuparams.n_threads == 1010);
 
@@ -109,7 +109,7 @@ int main(void) {
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name", "-m", "overwritten.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "overwritten.gguf");
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
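
For reviewers trying the patch locally, a minimal usage sketch, not part of the diff; the file name my_tool.cpp and its build setup are hypothetical, but the calls mirror the lookup and parallel examples above. An example program passes its own llama_example value to gpt_params_parser_init, and any option registered with .set_sparam() is then listed under the new "sampling params" section of the help output, while .set_examples({...}) options land under "example-specific params".

// my_tool.cpp -- hypothetical example, not part of the diff above; it only
// illustrates the API this patch exercises. Build against llama.cpp's common library.
#include "arg.h"
#include "common.h"

int main(int argc, char ** argv) {
    gpt_params params;

    // Initialize the parser with this example's own enum value. Options registered
    // with .set_examples({LLAMA_EXAMPLE_LOOKUP, ...}) are printed under
    // "example-specific params" in --help, options registered with .set_sparam()
    // under "sampling params", and everything else under "common params".
    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
    if (!gpt_params_parse(argc, argv, ctx_arg)) {
        return 1;
    }

    // Parsed sampling options (seed, temperature, penalties, ...) live in params.sparams.
    return 0;
}

This is the same pattern the lookup-create, lookup-stats, lookup, and parallel examples switch to in this patch.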