Merge branch 'ggerganov:master' into master

This commit is contained in:
dennyxbox890 2024-10-15 21:34:54 +08:00 committed by GitHub
commit a0c403e4f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 230 additions and 61 deletions

View file

@ -130,6 +130,7 @@ Typically finetunes of the base models below are supported as well.
- Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart)
- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326)
- Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp)
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
**UI:**

View file

@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.tfs_z = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
[](common_params & params, const std::string & value) {
params.sparams.xtc_probability = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-threshold"}, "N",
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sparams.xtc_threshold = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),

View file

@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");

View file

@ -90,6 +90,8 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TFS_Z = 4,
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,
};
// dimensionality reduction methods, used by cvector-generator
@ -108,6 +110,8 @@ struct common_sampler_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float tfs_z = 1.00f; // 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@ -124,12 +128,14 @@ struct common_sampler_params {
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TFS_Z,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE
};

View file

@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
top_k, tfs_z, top_p, min_p, typ_p, temp,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
break;
@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
default : return '?';
}
}
@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
default : return "";
}
}
@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
};
// since samplers names are written multiple ways
@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
};
std::vector<common_sampler_type> samplers;

View file

@ -241,6 +241,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres
Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
### XTC Sampling
- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
### Logit Bias
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.

View file

@ -524,10 +524,12 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
- `input_prefix`: Set the prefix of the code to infill.
- `input_suffix`: Set the suffix of the code to infill.
- `prompt`: Added after the `FIM_MID` token
- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874
- `input_extra`: Additional context inserted before the FIM prefix.
- `prompt`: Added after the `FIM_MID` token
It also accepts all the options of `/completion`.
`input_extra` is array of `{"filename": string, "text": string}` objects.
The endpoint also accepts all the options of `/completion`.
If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
@ -545,7 +547,7 @@ If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](
If the tokens are missing, then the extra context is simply prefixed at the start:
```txt
[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
[input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```
### **GET** `/props`: Get server global properties.

View file

@ -43,6 +43,8 @@
top_k: 0, // <= 0 to use vocab size
top_p: 1.0, // 1.0 = disabled
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled
@ -836,6 +838,8 @@ return html`
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
</fieldset>
@ -1132,6 +1136,8 @@ document.addEventListener('DOMContentLoaded', (event) => {
const snapSettings = {
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },

View file

@ -307,6 +307,8 @@
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled
xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled
@ -1013,6 +1015,8 @@
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
</fieldset>
<hr />
<fieldset class="three">

File diff suppressed because one or more lines are too long

View file

@ -136,10 +136,6 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<std::string> antiprompt;
json input_prefix;
json input_suffix;
json extra_context;
};
struct server_slot {
@ -169,6 +165,10 @@ struct server_slot {
json prompt; // can be either a string, array of strings or array of token ids
json input_prefix;
json input_suffix;
json input_extra;
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
@ -863,6 +863,8 @@ struct server_context {
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
@ -908,12 +910,12 @@ struct server_context {
}
// infill
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
slot.input_prefix = json_value(data, "input_prefix", json());
slot.input_suffix = json_value(data, "input_suffix", json());
slot.input_extra = json_value(data, "input_extra", json());
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
for (const auto & chunk : slot.params.extra_context) {
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
@ -930,7 +932,7 @@ struct server_context {
}
// get prompt
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
{
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@ -1196,6 +1198,8 @@ struct server_context {
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"xtc_probability", slot.sparams.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
@ -1954,6 +1958,8 @@ struct server_context {
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
@ -1964,10 +1970,11 @@ struct server_context {
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
auto tokens_prompt = tokenize(slot.prompt, false, false);
slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@ -1977,7 +1984,7 @@ struct server_context {
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : slot.params.extra_context) {
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
@ -2008,20 +2015,21 @@ struct server_context {
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
suffix_tokens.resize(n_suffix_take);
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
@ -2136,40 +2144,17 @@ struct server_context {
while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) {
if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
break;
}
if (llama_token_is_control(model, prompt_tokens[head_p]) &&
prompt_tokens[head_p] != llama_token_fim_rep(model) &&
prompt_tokens[head_p] != llama_token_fim_sep(model)) {
break;
}
size_t n_match = 0;
while (head_c + n_match < slot.cache_tokens.size() &&
head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
break;
}
if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
break;
}
n_match++;
}
if (n_match >= (size_t) params.n_cache_reuse) {
SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
//}

View file

@ -1101,6 +1101,9 @@ extern "C" {
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

View file

@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
};
}
// xtc
struct llama_sampler_xtc {
const float probability;
const float threshold;
const size_t min_keep;
const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng;
};
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
return "xtc";
}
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
if (ctx->probability <= 0.0f
|| ctx->threshold > 0.5f
|| cur_p->size < 2) {
return;
}
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
float chance = distribution(ctx->rng);
if (chance > ctx->probability) return;
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
int pos_last = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].p >= ctx->threshold) {
pos_last = i;
} else break;
}
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
cur_p->data += pos_last;
cur_p->size -= pos_last;
}
}
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
// copy the state
{
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
result_ctx->rng = ctx->rng;
}
return result;
}
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
delete (llama_sampler_xtc *) smpl->ctx;
}
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}
static struct llama_sampler_i llama_sampler_xtc_i = {
/* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply,
/* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free,
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
/* .threshold = */ t,
/* .min_keep = */ min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
// mirostat
struct llama_sampler_mirostat {

View file

@ -111,6 +111,28 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
}
}
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
const size_t n_vocab = probs.size();
std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
APPLY(llama_sampler_init_softmax(), &cur_p);
DUMP(&cur_p);
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
DUMP(&cur_p);
GGML_ASSERT(cur_p.size == expected_probs.size());
for (size_t i = 0; i < cur_p.size; i++) {
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
}
}
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
@ -263,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
}
const int64_t t_end = ggml_time_us();
llama_sampler_free(cnstr);
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
}
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
@ -279,12 +301,13 @@ static void test_perf() {
data.emplace_back(llama_token_data{i, logit, 0.0f});
}
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_softmax (), data, 32);
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
BENCH(llama_sampler_init_softmax (), data, 32);
}
int main(void) {
@ -309,6 +332,14 @@ int main(void) {
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
printf("XTC should:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
printf("XTC should not:\n");
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);