llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
parent 947538acb8
commit df270ef745

48 changed files with 3497 additions and 2914 deletions
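The headline change is that sampling is now expressed as a chain of composable `llama_sampler` objects. A minimal sketch of the new usage, based on the API names listed above (exact signatures are inferred from the post-refactor `llama.h`, not quoted from this diff):

```cpp
// sketch: assemble a sampler chain with the new llama_sampler_ / llama_sampler_chain_ API
// assumes the post-refactor llama.h; signatures are inferred from the commit message above
#include "llama.h"

static llama_token sample_next(struct llama_context * ctx, uint32_t seed) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // samplers run in the order they were added to the chain
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed)); // final step draws the token

    const llama_token tok = llama_sampler_sample(chain, ctx, -1);  // -1 = last output logits

    llama_sampler_free(chain); // also frees the samplers owned by the chain
    return tok;
}
```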
@@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
 }

 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         get_env("HF_TOKEN", params.hf_token);
     }

+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
         string_process_escapes(params.input_suffix);
-        string_process_escapes(sparams.cfg_negative_prompt);
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
@@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }

+    if (sparams.seed == LLAMA_DEFAULT_SEED) {
+        sparams.seed = time(NULL);
+    }
+
     return true;
 }

@@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';

-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
-        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
-        params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
         return true;
     }
@@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     if (arg == "--samplers") {
         CHECK_ARG
         const auto sampler_names = string_split(argv[i], ';');
-        sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
+        sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
         return true;
     }
     if (arg == "--sampling-seq") {
         CHECK_ARG
-        sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
+        sparams.samplers = gpt_sampler_types_from_chars(argv[i]);
         return true;
     }
     if (arg == "--top-p") {
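The hunk above reroutes `--samplers` and `--sampling-seq` to the new `gpt_sampler_types_from_names` / `gpt_sampler_types_from_chars` helpers in `common`. A small sketch of the two equivalent spellings (the single-character codes are an assumption, e.g. `k`/`p`/`t` for top-k/top-p/temperature):

```cpp
// sketch: two ways to request the same sampler order, using the helpers from the hunk above
const auto sampler_names = string_split("top_k;top_p;temperature", ';');
const auto samplers      = gpt_sampler_types_from_names(sampler_names, /*allow_alt_names=*/true);

// shorthand form consumed by --sampling-seq; character codes are assumed
const auto samplers_seq  = gpt_sampler_types_from_chars("kpt");
```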
@@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     }
     if (arg == "--typical") {
         CHECK_ARG
-        sparams.typical_p = std::stof(argv[i]);
+        sparams.typ_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--repeat-last-n") {
@@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
         sparams.mirostat_tau = std::stof(argv[i]);
         return true;
     }
-    if (arg == "--cfg-negative-prompt") {
-        CHECK_ARG
-        sparams.cfg_negative_prompt = argv[i];
-        return true;
-    }
-    if (arg == "--cfg-negative-prompt-file") {
-        CHECK_ARG
-        std::ifstream file(argv[i]);
-        if (!file) {
-            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
-        if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
-            sparams.cfg_negative_prompt.pop_back();
-        }
-        return true;
-    }
-    if (arg == "--cfg-scale") {
-        CHECK_ARG
-        sparams.cfg_scale = std::stof(argv[i]);
-        return true;
-    }
     if (arg == "-b" || arg == "--batch-size") {
         CHECK_ARG
         params.n_batch = std::stoi(argv[i]);
@@ -1355,7 +1333,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1370,7 +1348,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
         std::string value_str;
         try {
             if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                sparams.logit_bias.push_back({key, bias});
             }
             else {
                 throw std::exception();
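The logit-bias change above replaces the old map-style store (`sparams.logit_bias[key] = ...`) with a flat vector of token/bias pairs. A hedged sketch of how such a vector presumably feeds the new sampler API (the `llama_logit_bias` pair type and `llama_sampler_init_logit_bias` are assumptions based on this refactor):

```cpp
// sketch: logit biases as a flat array of { token, bias } pairs
std::vector<llama_logit_bias> logit_bias;
logit_bias.push_back({ 15043, 1.0f }); // e.g. --logit-bias 15043+1 boosts token ' Hello'

// assumed: the vector plugs straight into a sampler in the chain
llama_sampler_chain_add(chain,
        llama_sampler_init_logit_bias(llama_n_vocab(model), (int32_t) logit_bias.size(), logit_bias.data()));
```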
@@ -1725,13 +1704,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
 #endif

 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     std::string sampler_type_chars;
     std::string sampler_type_names;
-    for (const auto sampler_type : sparams.samplers_sequence) {
-        sampler_type_chars += static_cast<char>(sampler_type);
-        sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+    for (const auto & sampler : sparams.samplers) {
+        sampler_type_chars += gpt_sampler_type_to_chr(sampler);
+        sampler_type_names += gpt_sampler_type_to_str(sampler) + ";";
     }
     sampler_type_names.pop_back();

@@ -1766,7 +1745,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
     options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
     options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
-    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
     options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
     options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
     options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
@@ -1846,18 +1824,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
         " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

     options.push_back({ "sampling" });
+    options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
     options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
         "(default: %s)", sampler_type_names.c_str() });
     options.push_back({ "*", " --sampling-seq SEQUENCE",
         "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
     options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
     options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
-    options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
+    options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
     options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
-    options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
-    options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
-    options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
-    options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+    options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+    options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+    options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+    options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
     options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
     options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
     options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@@ -1872,11 +1851,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
         "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
         "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
-    options.push_back({ "main", " --cfg-negative-prompt PROMPT",
-        "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
-    options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
-        "negative prompt file to use for guidance" });
-    options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
     options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
         "set custom jinja chat template (default: template taken from model's metadata)\n"
         "if suffix/prefix are specified, template will be disabled\n"
@@ -2528,8 +2502,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }

-    if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+        fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sparams.ignore_eos = false;
     }

     if (params.warmup) {
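With the hunk above, `--ignore-eos` becomes a sampling parameter: init-time code now only warns and clears the flag when the model has no EOS token. The flag is presumably honored downstream by pushing a `-INFINITY` bias for EOS, along these lines (a sketch of the likely mechanism, not the literal implementation):

```cpp
// sketch: how ignore_eos plausibly maps onto the new logit_bias vector
if (params.sparams.ignore_eos) {
    params.sparams.logit_bias.push_back({ llama_token_eos(model), -INFINITY });
}
```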
@@ -2558,7 +2533,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
-        llama_reset_timings(lctx);
+        llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT);
     }

     iparams.model = model;
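Timing moves from `llama_reset_timings` / `llama_print_timings` to the new `llama_perf_` API, parameterized by what is being measured. Based on the call in the hunk above, usage presumably looks like this (`llama_perf_print` and the sampler-chain perf type are assumed counterparts):

```cpp
// context-level counters (prompt/decode timings)
llama_perf_print(lctx, LLAMA_PERF_TYPE_CONTEXT); // assumed print counterpart to llama_perf_reset
llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT);

// a sampler chain keeps its own counters under the same API (smpl is a hypothetical chain)
llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
```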
@@ -2637,7 +2612,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     cparams.n_threads = params.cpuparams.n_threads;
     cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
         params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.seed = params.seed;
     cparams.logits_all = params.logits_all;
     cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -3523,7 +3497,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) {

 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -3574,8 +3548,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,

     fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
     fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
-    yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
-    fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
@@ -3586,10 +3558,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);

-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");

     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
     fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3600,11 +3569,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
     fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());

     fprintf(stream, "logit_bias:\n");
-    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
-            continue;
-        }
-        fprintf(stream, "  %d: %f", lb.first, lb.second);
+    for (const auto & logit_bias : sparams.logit_bias) {
+        fprintf(stream, "  %d: %f", logit_bias.token, logit_bias.bias);
     }

     fprintf(stream, "lora:\n");
@@ -3657,7 +3623,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,

     fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
-    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
@@ -3671,7 +3636,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
-    fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+    fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
 }