better categorize args

commit 9ea1e93591 (parent 5d399f5689)
Author: Xuan Son Nguyen
Date:   2024-09-09 15:44:48 +02:00

7 changed files with 72 additions and 56 deletions

[file 1 of 7]

@@ -22,6 +22,11 @@ llama_arg & llama_arg::set_env(const char * env) {
     return *this;
 }
 
+llama_arg & llama_arg::set_sparam() {
+    is_sparam = true;
+    return *this;
+}
+
 bool llama_arg::in_example(enum llama_example ex) {
     return examples.find(ex) != examples.end();
 }
@@ -289,19 +294,24 @@ static void gpt_params_print_usage(llama_arg_context & ctx_arg) {
     };
     std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> sparam_options;
     std::vector<llama_arg *> specific_options;
     for (auto & opt : ctx_arg.options) {
         // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
-        if (opt.in_example(ctx_arg.ex)) {
+        if (opt.is_sparam) {
+            sparam_options.push_back(&opt);
+        } else if (opt.in_example(ctx_arg.ex)) {
             specific_options.push_back(&opt);
         } else {
             common_options.push_back(&opt);
         }
     }
-    printf("----- common options -----\n\n");
+    printf("----- common params -----\n\n");
     print_options(common_options);
+    printf("\n\n----- sampling params -----\n\n");
+    print_options(sparam_options);
     // TODO: maybe convert enum llama_example to string
-    printf("\n\n----- example-specific options -----\n\n");
+    printf("\n\n----- example-specific params -----\n\n");
     print_options(specific_options);
 }
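
For context: an option tagged with set_sparam() now always lands in the "sampling params" bucket, even when it also matches the current example, because the is_sparam check comes first. A minimal usage sketch of the tagging pattern (copied from the registrations later in this diff):

    add_opt(llama_arg(
        {"--temp"}, "N",
        format("temperature (default: %.1f)", (double)params.sparams.temp),
        [](gpt_params & params, const std::string & value) {
            params.sparams.temp = std::stof(value);
        }
    ).set_sparam()); // listed under "----- sampling params -----"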
@@ -412,13 +422,6 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.use_color = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
-    add_opt(llama_arg(
-        {"-s", "--seed"}, "SEED",
-        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
-        [](gpt_params & params, const std::string & value) {
-            params.sparams.seed = std::stoul(value);
-        }
-    ));
     add_opt(llama_arg(
         {"-t", "--threads"}, "N",
         format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
@@ -641,7 +644,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_draft = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-ps", "--p-split"}, "N",
         format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -655,14 +658,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & value) {
             params.lookup_cache_static = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-lcd", "--lookup-cache-dynamic"}, "FNAME",
         "path to dynamic lookup cache to use for lookup decoding (updated by generation)",
         [](gpt_params & params, const std::string & value) {
             params.lookup_cache_dynamic = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_LOOKUP}));
     add_opt(llama_arg(
         {"-c", "--ctx-size"}, "N",
         format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -704,7 +707,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_chunks = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(llama_arg(
         {"-fa", "--flash-attn"},
         format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -747,7 +750,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             }
             params.in_files.push_back(value);
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(llama_arg(
         {"-bf", "--binary-file"}, "FNAME",
         "binary file containing the prompt (default: none)",
@@ -902,28 +905,35 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             const auto sampler_names = string_split(value, ';');
             params.sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
         }
-    ));
+    ).set_sparam());
+    add_opt(llama_arg(
+        {"-s", "--seed"}, "SEED",
+        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.seed = std::stoul(value);
+        }
+    ).set_sparam());
     add_opt(llama_arg(
         {"--sampling-seq"}, "SEQUENCE",
         format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](gpt_params & params, const std::string & value) {
             params.sparams.samplers = gpt_sampler_types_from_chars(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--ignore-eos"},
         "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)",
         [](gpt_params & params) {
             params.sparams.ignore_eos = true;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--penalize-nl"},
         format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"),
         [](gpt_params & params) {
             params.sparams.penalize_nl = true;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--temp"}, "N",
         format("temperature (default: %.1f)", (double)params.sparams.temp),
@@ -931,42 +941,42 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.temp = std::stof(value);
             params.sparams.temp = std::max(params.sparams.temp, 0.0f);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--top-k"}, "N",
         format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
         [](gpt_params & params, int value) {
             params.sparams.top_k = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--top-p"}, "N",
         format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.top_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--min-p"}, "N",
         format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.min_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--tfs"}, "N",
         format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
         [](gpt_params & params, const std::string & value) {
             params.sparams.tfs_z = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
         [](gpt_params & params, const std::string & value) {
             params.sparams.typ_p = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--repeat-last-n"}, "N",
         format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n),
@@ -974,42 +984,42 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.penalty_last_n = value;
             params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--repeat-penalty"}, "N",
         format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_repeat = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--presence-penalty"}, "N",
         format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_present = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--frequency-penalty"}, "N",
         format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq),
         [](gpt_params & params, const std::string & value) {
             params.sparams.penalty_freq = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--dynatemp-range"}, "N",
         format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
         [](gpt_params & params, const std::string & value) {
             params.sparams.dynatemp_range = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--dynatemp-exp"}, "N",
         format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent),
         [](gpt_params & params, const std::string & value) {
             params.sparams.dynatemp_exponent = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat"}, "N",
         format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
@@ -1017,21 +1027,21 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.sparams.mirostat = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat-lr"}, "N",
         format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta),
         [](gpt_params & params, const std::string & value) {
             params.sparams.mirostat_eta = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--mirostat-ent"}, "N",
         format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau),
         [](gpt_params & params, const std::string & value) {
             params.sparams.mirostat_tau = std::stof(value);
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"-l", "--logit-bias"}, "TOKEN_ID(+/-)BIAS",
         "modifies the likelihood of token appearing in the completion,\n"
@@ -1053,14 +1063,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
                 throw std::invalid_argument("invalid input format");
             }
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--grammar"}, "GRAMMAR",
         format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()),
         [](gpt_params & params, const std::string & value) {
             params.sparams.grammar = value;
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--grammar-file"}, "FNAME",
         "file to read grammar from",
@@ -1075,14 +1085,14 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
                 std::back_inserter(params.sparams.grammar)
             );
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"-j", "--json-schema"}, "SCHEMA",
         "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
         [](gpt_params & params, const std::string & value) {
             params.sparams.grammar = json_schema_to_grammar(json::parse(value));
         }
-    ));
+    ).set_sparam());
     add_opt(llama_arg(
         {"--pooling"}, "{none,mean,cls,last}",
         "pooling type for embeddings, use model default if unspecified",
@@ -1310,21 +1320,21 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, int value) {
             params.n_sequences = value;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_PARALLEL}));
     add_opt(llama_arg(
         {"-cb", "--cont-batching"},
         format("enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled"),
         [](gpt_params & params) {
             params.cont_batching = true;
         }
-    ).set_env("LLAMA_ARG_CONT_BATCHING"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CONT_BATCHING"));
     add_opt(llama_arg(
         {"-nocb", "--no-cont-batching"},
         "disable continuous batching",
         [](gpt_params & params) {
             params.cont_batching = false;
         }
-    ).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
     add_opt(llama_arg(
         {"--mmproj"}, "FILE",
         "path to a multimodal projector file for LLaVA. see examples/llava/README.md",
@@ -1487,6 +1497,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & value) {
             params.lora_adapters.push_back({ std::string(value), 1.0 });
         }
+    // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(llama_arg(
         {"--lora-scaled"}, "FNAME", "SCALE",
@@ -1494,6 +1505,7 @@ llama_arg_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params, const std::string & fname, const std::string & scale) {
             params.lora_adapters.push_back({ fname, std::stof(scale) });
         }
+    // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(llama_arg(
         {"--control-vector"}, "FNAME",

[file 2 of 7]

@@ -27,6 +27,8 @@ enum llama_example {
     LLAMA_EXAMPLE_CVECTOR_GENERATOR,
     LLAMA_EXAMPLE_EXPORT_LORA,
     LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_LOOKUP,
+    LLAMA_EXAMPLE_PARALLEL,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -38,6 +40,7 @@ struct llama_arg {
     const char * value_hint_2 = nullptr; // for second arg value
     const char * env = nullptr;
     std::string help;
+    bool is_sparam = false; // is current arg a sampling param?
     void (*handler_void)   (gpt_params & params) = nullptr;
     void (*handler_string) (gpt_params & params, const std::string &) = nullptr;
     void (*handler_str_str)(gpt_params & params, const std::string &, const std::string &) = nullptr;
@@ -74,6 +77,7 @@ struct llama_arg {
 
     llama_arg & set_examples(std::initializer_list<enum llama_example> examples);
     llama_arg & set_env(const char * env);
+    llama_arg & set_sparam();
     bool in_example(enum llama_example ex);
     bool get_value_from_env(std::string & output);
     bool has_value_from_env();
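
Since every setter returns llama_arg &, the tags compose by chaining. A small sketch combining the setters declared above, following the pattern used elsewhere in this commit:

    add_opt(llama_arg(
        {"-cb", "--cont-batching"},
        "enable continuous batching",
        [](gpt_params & params) { params.cont_batching = true; }
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CONT_BATCHING"));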

[file 3 of 7]

@@ -13,7 +13,7 @@
 int main(int argc, char ** argv){
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }

[file 4 of 7]

@@ -15,7 +15,7 @@
 int main(int argc, char ** argv){
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_SPECULATIVE);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }

[file 5 of 7]

@@ -12,7 +12,7 @@
 int main(int argc, char ** argv){
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }
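
The three examples above (presumably the lookup tools) now initialize the parser with LLAMA_EXAMPLE_LOOKUP instead of LLAMA_EXAMPLE_COMMON or LLAMA_EXAMPLE_SPECULATIVE, so flags such as --draft, -lcs and -lcd are grouped under "example-specific params" in their help output. The shared entry pattern after this change:

    int main(int argc, char ** argv){
        gpt_params params;

        auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_LOOKUP);
        if (!gpt_params_parse(argc, argv, ctx_arg)) {
            return 1;
        }
        // ... tool-specific work follows ...
    }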

[file 6 of 7]

@@ -100,7 +100,7 @@ int main(int argc, char ** argv) {
 
     gpt_params params;
 
-    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
+    auto ctx_arg = gpt_params_parser_init(params, LLAMA_EXAMPLE_PARALLEL);
     if (!gpt_params_parse(argc, argv, ctx_arg)) {
         return 1;
     }

[file 7 of 7]

@@ -17,7 +17,7 @@ int main(void) {
         auto ctx_arg = gpt_params_parser_init(params, (enum llama_example)ex);
         std::unordered_set<std::string> seen_args;
         std::unordered_set<std::string> seen_env_vars;
-        for (const auto & opt : options) {
+        for (const auto & opt : ctx_arg.options) {
             // check for args duplications
             for (const auto & arg : opt.args) {
                 if (seen_args.find(arg) == seen_args.end()) {
@@ -57,31 +57,31 @@ int main(void) {
     printf("test-arg-parser: test invalid usage\n\n");
 
     argv = {"binary_name", "-m"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     argv = {"binary_name", "-ngl", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     argv = {"binary_name", "-sm", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
 
     printf("test-arg-parser: test valid usage\n\n");
 
     argv = {"binary_name", "-m", "model_file.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "model_file.gguf");
 
     argv = {"binary_name", "-t", "1234"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.cpuparams.n_threads == 1234);
 
     argv = {"binary_name", "--verbose"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.verbosity == 1);
 
     argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "abc.gguf");
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
@@ -94,12 +94,12 @@ int main(void) {
 
     setenv("LLAMA_ARG_THREADS", "blah", true);
     argv = {"binary_name"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
 
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "blah.gguf");
     assert(params.cpuparams.n_threads == 1010);
 
@@ -109,7 +109,7 @@ int main(void) {
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name", "-m", "overwritten.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), ctx_arg));
     assert(params.model == "overwritten.gguf");
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
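
The test builds argv from a std::vector<std::string> via list_str_to_char, which is defined in the test file but not shown in this diff. A hypothetical equivalent, for readers following the call sites above (the real helper may differ):

    // hypothetical argv adapter; non-owning, the pointers alias the strings in `args`
    static std::vector<char *> list_str_to_char(std::vector<std::string> & args) {
        std::vector<char *> res;
        res.reserve(args.size());
        for (auto & s : args) {
            res.push_back(const_cast<char *>(s.data()));
        }
        return res;
    }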