llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
}
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
bool invalid_param = false;
std::string arg;
const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sparams;
for (int i = 1; i < argc; i++) {
arg = argv[i];
const std::string arg_prefix = "--";
std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
bool invalid_param = false;
if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
throw std::invalid_argument("error: unknown argument: " + arg);
}
@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
get_env("HF_TOKEN", params.hf_token);
}
auto & sparams = params.sparams;
if (params.escape) {
string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix);
string_process_escapes(params.input_suffix);
string_process_escapes(sparams.cfg_negative_prompt);
for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt);
}
@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.back().key[0] = 0;
}
if (sparams.seed == LLAMA_DEFAULT_SEED) {
sparams.seed = time(NULL);
}
return true;
}
@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
const char split_delim = ',';
llama_sampling_params & sparams = params.sparams;
auto & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") {
CHECK_ARG
// TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
if (arg == "--samplers") {
CHECK_ARG
const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
return true;
}
if (arg == "--sampling-seq") {
CHECK_ARG
sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
sparams.samplers = gpt_sampler_types_from_chars(argv[i]);
return true;
}
if (arg == "--top-p") {
@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--typical") {
CHECK_ARG
sparams.typical_p = std::stof(argv[i]);
sparams.typ_p = std::stof(argv[i]);
return true;
}
if (arg == "--repeat-last-n") {
@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.mirostat_tau = std::stof(argv[i]);
return true;
}
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
return true;
}
if (arg == "--cfg-negative-prompt-file") {
CHECK_ARG
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
return true;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
sparams.cfg_negative_prompt.pop_back();
}
return true;
}
if (arg == "--cfg-scale") {
CHECK_ARG
sparams.cfg_scale = std::stof(argv[i]);
return true;
}
if (arg == "-b" || arg == "--batch-size") {
CHECK_ARG
params.n_batch = std::stoi(argv[i]);
@ -1355,7 +1333,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--ignore-eos") {
params.ignore_eos = true;
sparams.ignore_eos = true;
return true;
}
if (arg == "--penalize-nl") {
@ -1370,7 +1348,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
std::string value_str;
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
sparams.logit_bias.push_back({key, bias});
}
else {
throw std::exception();
@ -1725,13 +1704,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
#endif
void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sparams;
const auto & sparams = params.sparams;
std::string sampler_type_chars;
std::string sampler_type_names;
for (const auto sampler_type : sparams.samplers_sequence) {
sampler_type_chars += static_cast<char>(sampler_type);
sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
for (const auto & sampler : sparams.samplers) {
sampler_type_chars += gpt_sampler_type_to_chr(sampler);
sampler_type_names += gpt_sampler_type_to_str(sampler) + ";";
}
sampler_type_names.pop_back();
@ -1766,7 +1745,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
@ -1846,18 +1824,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
" --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });
options.push_back({ "sampling" });
options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
"(default: %s)", sampler_type_names.c_str() });
options.push_back({ "*", " --sampling-seq SEQUENCE",
"simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@ -1872,11 +1851,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
options.push_back({ "main", " --cfg-negative-prompt PROMPT",
"negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
@ -2528,8 +2502,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
}
if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sparams.ignore_eos = false;
}
if (params.warmup) {
@ -2558,7 +2533,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
}
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
llama_reset_timings(lctx);
llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT);
}
iparams.model = model;
@ -2637,7 +2612,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
@ -3523,7 +3497,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sparams;
const auto & sparams = params.sparams;
fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@ -3574,8 +3548,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
@ -3586,10 +3558,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@ -3600,11 +3569,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
fprintf(stream, "logit_bias:\n");
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) {
continue;
}
fprintf(stream, " %d: %f", lb.first, lb.second);
for (const auto & logit_bias : sparams.logit_bias) {
fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
}
fprintf(stream, "lora:\n");
@ -3657,7 +3623,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
@ -3671,7 +3636,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}