Merge remote-tracking branch 'upstream/master' into fix-stop-trim
This commit is contained in:
commit
6e13d3faef
22 changed files with 749 additions and 94 deletions
|
@ -130,6 +130,7 @@ Typically finetunes of the base models below are supported as well.
|
|||
- Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart)
|
||||
- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326)
|
||||
- Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp)
|
||||
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
|
||||
|
||||
**UI:**
|
||||
|
||||
|
|
|
@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.sparams.tfs_z = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sparams.xtc_probability = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-threshold"}, "N",
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sparams.xtc_threshold = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
|
||||
|
@ -1788,6 +1802,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.n_threads_http = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
|
||||
add_opt(common_arg(
|
||||
{"--cache-reuse"}, "N",
|
||||
string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse),
|
||||
[](common_params & params, int value) {
|
||||
params.n_cache_reuse = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE"));
|
||||
add_opt(common_arg(
|
||||
{"--metrics"},
|
||||
string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
|
||||
|
|
|
@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
|
|||
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
|
||||
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
|
||||
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
|
||||
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
|
||||
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
|
||||
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
|
||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
||||
|
|
|
@ -90,6 +90,8 @@ enum common_sampler_type {
|
|||
COMMON_SAMPLER_TYPE_TFS_Z = 4,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
|
||||
COMMON_SAMPLER_TYPE_XTC = 7,
|
||||
COMMON_SAMPLER_TYPE_INFILL = 8,
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
|
@ -108,6 +110,8 @@ struct common_sampler_params {
|
|||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
|
@ -124,13 +128,15 @@ struct common_sampler_params {
|
|||
bool ignore_eos = false;
|
||||
bool no_perf = false; // disable performance metrics
|
||||
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
COMMON_SAMPLER_TYPE_TFS_Z,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
COMMON_SAMPLER_TYPE_MIN_P,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE
|
||||
COMMON_SAMPLER_TYPE_XTC,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
|
@ -277,7 +283,8 @@ struct common_params {
|
|||
int32_t port = 8080; // server listens on this network port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
|
|
|
@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
|
|||
|
||||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
top_k, tfs_z, top_p, min_p, typ_p, temp,
|
||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
return std::string(result);
|
||||
|
@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TFS_Z:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||
break;
|
||||
|
@ -193,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
|
@ -372,6 +378,8 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||
default : return '?';
|
||||
}
|
||||
}
|
||||
|
@ -384,6 +392,8 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||
default : return "";
|
||||
}
|
||||
}
|
||||
|
@ -396,6 +406,8 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
|
||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||
};
|
||||
|
||||
// since samplers names are written multiple ways
|
||||
|
@ -441,7 +453,9 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
|
|
@ -241,6 +241,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres
|
|||
|
||||
Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
|
||||
|
||||
### XTC Sampling
|
||||
|
||||
- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
|
||||
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
|
||||
|
||||
Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
|
||||
|
||||
By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
|
||||
|
||||
Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`.
|
||||
|
||||
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
|
||||
|
||||
### Logit Bias
|
||||
|
||||
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
|
||||
|
|
|
@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
|
|||
if (!params.ctx_shift){
|
||||
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
|
||||
break;
|
||||
} else {
|
||||
if (params.n_predict == -2) {
|
||||
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||
break;
|
||||
}
|
||||
|
||||
const int n_left = n_past - params.n_keep;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
LOG_DBG("after swap: n_past = %d\n", n_past);
|
||||
|
||||
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
LOG_DBG("clear session path\n");
|
||||
path_session.clear();
|
||||
}
|
||||
|
||||
if (params.n_predict == -2) {
|
||||
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||
break;
|
||||
}
|
||||
|
||||
const int n_left = n_past - params.n_keep;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
LOG_DBG("after swap: n_past = %d\n", n_past);
|
||||
|
||||
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
LOG_DBG("clear session path\n");
|
||||
path_session.clear();
|
||||
}
|
||||
} else {
|
||||
// context extension via Self-Extend
|
||||
|
|
|
@ -147,6 +147,7 @@ The project is under active development, and we are [looking for feedback and co
|
|||
| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
|
||||
| `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
|
||||
| `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
|
||||
| `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)<br/>(env: LLAMA_ARG_CACHE_REUSE) |
|
||||
| `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
|
||||
| `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
|
||||
| `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
|
||||
|
@ -523,8 +524,31 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
|
|||
|
||||
- `input_prefix`: Set the prefix of the code to infill.
|
||||
- `input_suffix`: Set the suffix of the code to infill.
|
||||
- `input_extra`: Additional context inserted before the FIM prefix.
|
||||
- `prompt`: Added after the `FIM_MID` token
|
||||
|
||||
It also accepts all the options of `/completion`.
|
||||
`input_extra` is array of `{"filename": string, "text": string}` objects.
|
||||
|
||||
The endpoint also accepts all the options of `/completion`.
|
||||
|
||||
If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
|
||||
|
||||
```txt
|
||||
<FIM_REP>myproject
|
||||
<FIM_SEP>{chunk 0 filename}
|
||||
{chunk 0 text}
|
||||
<FIM_SEP>{chunk 1 filename}
|
||||
{chunk 1 text}
|
||||
...
|
||||
<FIM_SEP>filename
|
||||
<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
|
||||
```
|
||||
|
||||
If the tokens are missing, then the extra context is simply prefixed at the start:
|
||||
|
||||
```txt
|
||||
[input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
|
||||
```
|
||||
|
||||
### **GET** `/props`: Get server global properties.
|
||||
|
||||
|
|
|
@ -43,6 +43,8 @@
|
|||
top_k: 0, // <= 0 to use vocab size
|
||||
top_p: 1.0, // 1.0 = disabled
|
||||
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
||||
xtc_probability: 0.0, // 0 = disabled;
|
||||
xtc_threshold: 0.1, // > 0.5 disables XTC;
|
||||
tfs_z: 1.0, // 1.0 = disabled
|
||||
typical_p: 1.0, // 1.0 = disabled
|
||||
presence_penalty: 0.0, // 0.0 = disabled
|
||||
|
@ -836,6 +838,8 @@ return html`
|
|||
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
|
||||
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
||||
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
||||
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||
</fieldset>
|
||||
|
||||
|
@ -1132,6 +1136,8 @@ document.addEventListener('DOMContentLoaded', (event) => {
|
|||
const snapSettings = {
|
||||
temperature: { snapValue: 1.0, snapRangeMultiplier: 6 },
|
||||
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
|
||||
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
|
||||
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||
|
|
|
@ -307,6 +307,8 @@
|
|||
top_k: 40, // <= 0 to use vocab size
|
||||
top_p: 0.95, // 1.0 = disabled
|
||||
min_p: 0.05, // 0 = disabled
|
||||
xtc_probability: 0.0, // 0 = disabled;
|
||||
xtc_threshold: 0.1, // > 0.5 disables XTC;
|
||||
tfs_z: 1.0, // 1.0 = disabled
|
||||
typical_p: 1.0, // 1.0 = disabled
|
||||
presence_penalty: 0.0, // 0.0 = disabled
|
||||
|
@ -1013,6 +1015,8 @@
|
|||
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
||||
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
|
||||
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
||||
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||
</fieldset>
|
||||
<hr />
|
||||
<fieldset class="three">
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -136,9 +136,6 @@ struct slot_params {
|
|||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
|
@ -168,8 +165,13 @@ struct server_slot {
|
|||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
json input_extra;
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
std::vector<llama_token> extra_tokens;
|
||||
|
||||
std::string generated_text;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
|
@ -800,7 +802,7 @@ struct server_context {
|
|||
int slot_prompt_len = slot_prompt.size();
|
||||
|
||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||
int lcp_len = common_part(slot_prompt, prompt);
|
||||
int lcp_len = longest_common_prefix(slot_prompt, prompt);
|
||||
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||
|
@ -861,6 +863,8 @@ struct server_context {
|
|||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
|
||||
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
|
||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||
|
@ -906,11 +910,29 @@ struct server_context {
|
|||
}
|
||||
|
||||
// infill
|
||||
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
|
||||
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
|
||||
slot.input_prefix = json_value(data, "input_prefix", json());
|
||||
slot.input_suffix = json_value(data, "input_suffix", json());
|
||||
slot.input_extra = json_value(data, "input_extra", json());
|
||||
|
||||
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk["text"].is_string()) {
|
||||
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
|
||||
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
|
||||
}
|
||||
|
||||
// get prompt
|
||||
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
|
||||
{
|
||||
const auto & prompt = data.find("prompt");
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
|
@ -1175,6 +1197,8 @@ struct server_context {
|
|||
{"top_k", slot.sparams.top_k},
|
||||
{"top_p", slot.sparams.top_p},
|
||||
{"min_p", slot.sparams.min_p},
|
||||
{"xtc_probability", slot.sparams.xtc_probability},
|
||||
{"xtc_threshold", slot.sparams.xtc_threshold},
|
||||
{"tfs_z", slot.sparams.tfs_z},
|
||||
{"typical_p", slot.sparams.typ_p},
|
||||
{"repeat_last_n", slot.sparams.penalty_last_n},
|
||||
|
@ -1933,26 +1957,88 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
||||
{
|
||||
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
|
||||
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
|
||||
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
|
||||
auto tokens_prompt = tokenize(slot.prompt, false, false);
|
||||
|
||||
prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
|
||||
suffix_tokens.resize(n_suffix_take);
|
||||
slot.extra_tokens.clear();
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
static const auto k_fim_repo = tokenize("myproject\n", false, false);
|
||||
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
|
||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
|
||||
slot.extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
|
||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
||||
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = tokenize(filename + "\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = tokenize(text, false, false);
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = tokenize("filename\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
|
||||
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
|
||||
|
@ -2011,7 +2097,7 @@ struct server_context {
|
|||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
||||
// if input prompt is too big, truncate it
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
|
@ -2041,12 +2127,59 @@ struct server_context {
|
|||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
size_t head_p = slot.n_past; // current prompt
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
|
||||
|
||||
while (head_c < slot.cache_tokens.size() &&
|
||||
head_p < prompt_tokens.size()) {
|
||||
|
||||
size_t n_match = 0;
|
||||
while (head_c + n_match < slot.cache_tokens.size() &&
|
||||
head_p + n_match < prompt_tokens.size() &&
|
||||
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
|
||||
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) params.n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
|
||||
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
head_c += n_match;
|
||||
head_p += n_match;
|
||||
} else {
|
||||
head_c += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3256,6 +3389,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
|
||||
|
|
|
@ -195,14 +195,14 @@ static std::string gen_chatcmplid() {
|
|||
// other common utils
|
||||
//
|
||||
|
||||
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
|
||||
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
static size_t common_part(const std::string & a, const std::string & b) {
|
||||
static size_t longest_common_prefix(const std::string & a, const std::string & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
|
@ -360,9 +360,9 @@ static json oaicompat_completion_params_parse(
|
|||
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (body.contains("logprobs")) {
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
||||
} else if (body.contains("top_logprobs")) {
|
||||
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
|
||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||
}
|
||||
|
||||
|
|
6
flake.lock
generated
6
flake.lock
generated
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1728018373,
|
||||
"narHash": "sha256-NOiTvBbRLIOe5F6RbHaAh6++BNjsb149fGZd1T4+KBg=",
|
||||
"lastModified": 1728492678,
|
||||
"narHash": "sha256-9UTxR8eukdg+XZeHgxW5hQA9fIKHsKCdOIUycTryeVw=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "bc947f541ae55e999ffdb4013441347d83b00feb",
|
||||
"rev": "5633bcff0c6162b9e4b5f1264264611e950c8ec7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
|
|||
|
||||
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
// load 2 halfs into register in a single instruction
|
||||
const half2 x_reg = *((half2 *) &(x[ib + iqs]));
|
||||
// automatic half -> float type cast if dfloat == float
|
||||
v.x = x[ib + iqs + 0];
|
||||
v.y = x[ib + iqs + 1];
|
||||
v.x = __low2float(x_reg);
|
||||
v.y = __high2float(x_reg);
|
||||
}
|
||||
|
||||
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
|
||||
|
@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
|
|||
// matrix multiplication
|
||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||
#ifdef GGML_CUDA_F16
|
||||
tmp += __hmul2(v, {
|
||||
y[iybs + iqs + j/qr + 0],
|
||||
y[iybs + iqs + j/qr + y_offset]
|
||||
});
|
||||
if ( y_offset == 1 ) {
|
||||
// load 2 dfloats into register in a single instruction
|
||||
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
|
||||
tmp += __hmul2(v, y_reg);
|
||||
}
|
||||
else {
|
||||
tmp += __hmul2(v, {
|
||||
y[iybs + iqs + j/qr + 0],
|
||||
y[iybs + iqs + j/qr + y_offset]
|
||||
});
|
||||
}
|
||||
#else
|
||||
tmp += v.x * y[iybs + iqs + j/qr + 0];
|
||||
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
|
||||
if ( y_offset == 1 ) {
|
||||
// load 2 dfloats into register in a single instruction
|
||||
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
|
||||
tmp += v.x * y_reg.x;
|
||||
tmp += v.y * y_reg.y;
|
||||
}
|
||||
else {
|
||||
tmp += v.x * y[iybs + iqs + j/qr + 0];
|
||||
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
|
||||
}
|
||||
#endif // GGML_CUDA_F16
|
||||
}
|
||||
}
|
||||
|
|
|
@ -953,6 +953,12 @@ extern "C" {
|
|||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
LLAMA_API bool llama_token_is_prefix(
|
||||
const struct llama_model * model,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
|
||||
/// @param text The char pointer must be large enough to hold the resulting text.
|
||||
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
|
||||
|
@ -1101,6 +1107,9 @@ extern "C" {
|
|||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
||||
|
||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
|
@ -1145,6 +1154,28 @@ extern "C" {
|
|||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
||||
// this sampler is meant to be used for fill-in-the-middle infilling
|
||||
// it's supposed to be used after top_k + top_p sampling
|
||||
//
|
||||
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
|
||||
// 2. combine probs of tokens that have the same prefix
|
||||
//
|
||||
// example:
|
||||
//
|
||||
// - before:
|
||||
// "hel": 0.5
|
||||
// "hell": 0.2
|
||||
// "hello": 0.1
|
||||
// "dummy": 0.1
|
||||
//
|
||||
// - after:
|
||||
// "hel": 0.8
|
||||
// "dummy": 0.1
|
||||
//
|
||||
// 3. discard non-EOG tokens with low prob
|
||||
// 4. if no tokens are left -> pick EOT
|
||||
//
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
|
|
@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
|||
};
|
||||
}
|
||||
|
||||
// xtc
|
||||
|
||||
struct llama_sampler_xtc {
|
||||
const float probability;
|
||||
const float threshold;
|
||||
const size_t min_keep;
|
||||
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "xtc";
|
||||
}
|
||||
|
||||
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
|
||||
if (ctx->probability <= 0.0f
|
||||
|| ctx->threshold > 0.5f
|
||||
|| cur_p->size < 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
|
||||
float chance = distribution(ctx->rng);
|
||||
if (chance > ctx->probability) return;
|
||||
|
||||
// in case it's not sorted/recalculated yet
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
int pos_last = 0;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i].p >= ctx->threshold) {
|
||||
pos_last = i;
|
||||
} else break;
|
||||
}
|
||||
|
||||
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
|
||||
cur_p->data += pos_last;
|
||||
cur_p->size -= pos_last;
|
||||
}
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_xtc *) result->ctx;
|
||||
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_xtc *) smpl->ctx;
|
||||
}
|
||||
|
||||
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||
/* .name = */ llama_sampler_xtc_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sample_xtc_apply,
|
||||
/* .reset = */ llama_sampler_xtc_reset,
|
||||
/* .clone = */ llama_sampler_xtc_clone,
|
||||
/* .free = */ llama_sampler_xtc_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_xtc_i,
|
||||
/* .ctx = */ new llama_sampler_xtc {
|
||||
/* .probability = */ p,
|
||||
/* .threshold = */ t,
|
||||
/* .min_keep = */ min_keep,
|
||||
/* .seed = */ seed,
|
||||
/* .seed_cur = */ seed_cur,
|
||||
/* .rng = */ std::mt19937(seed_cur),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// mirostat
|
||||
|
||||
struct llama_sampler_mirostat {
|
||||
|
@ -1644,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|||
};
|
||||
}
|
||||
|
||||
// infill
|
||||
|
||||
//#define GGML_DEBUG_SAMPLER_INFILL
|
||||
|
||||
struct llama_sampler_infill {
|
||||
const struct llama_vocab * vocab;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "infill";
|
||||
}
|
||||
|
||||
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_infill *) smpl->ctx;
|
||||
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
#if defined(GGML_DEBUG_SAMPLER_INFILL)
|
||||
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
|
||||
#else
|
||||
#define LOG_DBG_CUR(...)
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
float p_txt_sum = 0.0f;
|
||||
float p_eog_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
p_eog_sum += cur_p->data[i].p;
|
||||
} else {
|
||||
p_txt_sum += cur_p->data[i].p;
|
||||
}
|
||||
}
|
||||
|
||||
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
|
||||
|
||||
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
|
||||
|
||||
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
|
||||
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
|
||||
|
||||
// keep just the EOG tokens
|
||||
const auto size_org = cur_p->size;
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
float p_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
size_t n_combined = 0; GGML_UNUSED(n_combined);
|
||||
|
||||
// combine tokens with common prefix
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
for (size_t j = 0; j < cur_p->size; ++j) {
|
||||
if (cur_p->data[i].logit == -INFINITY) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i == j || cur_p->data[j].logit == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
|
||||
if (cur_p->data[i].p > cur_p->data[j].p) {
|
||||
cur_p->data[i].p += cur_p->data[j].p;
|
||||
cur_p->data[j].logit = -INFINITY;
|
||||
cur_p->data[j].p = 0.0f;
|
||||
} else {
|
||||
cur_p->data[j].p += cur_p->data[i].p;
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
cur_p->data[i].p = 0.0f;
|
||||
}
|
||||
|
||||
n_combined++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t n_non_eog = 0;
|
||||
|
||||
size_t size_org = cur_p->size;
|
||||
|
||||
float p_sum = 0.0f;
|
||||
float thold = 0.2f;
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!is_eog) {
|
||||
++n_non_eog;
|
||||
}
|
||||
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
// keep this token
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
|
||||
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
|
||||
|
||||
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
|
||||
if (n_non_eog == 0) {
|
||||
cur_p->size = 1;
|
||||
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
|
||||
cur_p->data[0].logit = 1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
size_org = cur_p->size;
|
||||
p_sum = 0.0f;
|
||||
thold = 1.0/(n_non_eog + 1);
|
||||
|
||||
cur_p->size = 0;
|
||||
|
||||
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
}
|
||||
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= p_sum;
|
||||
|
||||
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
||||
}
|
||||
|
||||
#undef LOG_DBG_CUR
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
|
||||
return llama_sampler_init_infill_impl(*ctx->vocab);
|
||||
}
|
||||
|
||||
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_infill *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_infill_i = {
|
||||
/* .name = */ llama_sampler_infill_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_infill_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_infill_clone,
|
||||
/* .free = */ llama_sampler_infill_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_infill_i,
|
||||
/* .ctx = */ new llama_sampler_infill {
|
||||
/* .vocab = */ &vocab,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// utils
|
||||
|
||||
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
struct llama_vocab;
|
||||
struct llama_grammar;
|
||||
|
||||
|
@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
|
|||
const struct llama_vocab & vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab);
|
||||
|
|
|
@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
|||
return 0;
|
||||
}
|
||||
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1) {
|
||||
char text_buf_0[128];
|
||||
char text_buf_1[128];
|
||||
|
||||
const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
|
||||
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
|
||||
|
||||
if (len0 <= 0 || len1 <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
|
||||
}
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
|
|
|
@ -48,7 +48,7 @@ struct llama_vocab {
|
|||
id special_cls_id = LLAMA_TOKEN_NULL;
|
||||
id special_mask_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
id linefeed_id = 13;
|
||||
id linefeed_id = 13;
|
||||
|
||||
// fim tokens
|
||||
id special_fim_pre_id = LLAMA_TOKEN_NULL;
|
||||
|
@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
|
|||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
|
|
|
@ -6596,8 +6596,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_eot_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6610,8 +6610,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_eom_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6627,8 +6627,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_pre_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6644,8 +6644,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_suf_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6661,8 +6661,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_mid_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6677,8 +6677,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_pad_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6694,8 +6694,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_rep_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6708,8 +6708,8 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_fim_sep_id = t.second;
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
}
|
||||
|
@ -6720,6 +6720,19 @@ static void llm_load_vocab(
|
|||
// this is currently determined based on the token text, which is obviously not ideal
|
||||
// ref: https://github.com/ggerganov/llama.cpp/issues/9606
|
||||
vocab.special_eog_ids.clear();
|
||||
|
||||
if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
|
||||
vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
|
||||
}
|
||||
|
||||
if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
|
||||
vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
|
||||
}
|
||||
|
||||
if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
|
||||
vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
|
||||
}
|
||||
|
||||
for (const auto & t : vocab.token_to_id) {
|
||||
if (false
|
||||
|| t.first == "<|eot_id|>"
|
||||
|
@ -6732,13 +6745,20 @@ static void llm_load_vocab(
|
|||
) {
|
||||
vocab.special_eog_ids.insert(t.second);
|
||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.first.c_str());
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
}
|
||||
} else {
|
||||
// token is control, but not marked as EOG -> print a warning
|
||||
if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control token: %6d '%s' is not marked as EOG\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanity checks
|
||||
if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
|
||||
vocab.special_eog_ids.insert(vocab.special_eos_id);
|
||||
LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||
|
@ -21480,6 +21500,13 @@ int32_t llama_token_to_piece(
|
|||
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
|
||||
}
|
||||
|
||||
bool llama_token_is_prefix(
|
||||
const struct llama_model * model,
|
||||
llama_token token0,
|
||||
llama_token token1) {
|
||||
return llama_token_is_prefix_impl(model->vocab, token0, token1);
|
||||
}
|
||||
|
||||
int32_t llama_detokenize(
|
||||
const struct llama_model * model,
|
||||
const llama_token * tokens,
|
||||
|
@ -21810,6 +21837,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
|
|||
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
|
||||
return llama_sampler_init_infill_impl(model->vocab);
|
||||
}
|
||||
|
||||
//
|
||||
// model split
|
||||
//
|
||||
|
|
|
@ -111,6 +111,28 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
|||
}
|
||||
}
|
||||
|
||||
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
|
||||
const size_t n_vocab = probs.size();
|
||||
|
||||
std::vector<llama_token_data> cur;
|
||||
cur.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
const float logit = logf(probs[token_id]);
|
||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||
DUMP(&cur_p);
|
||||
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
|
||||
DUMP(&cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
|
||||
|
@ -263,7 +285,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
|
|||
}
|
||||
const int64_t t_end = ggml_time_us();
|
||||
llama_sampler_free(cnstr);
|
||||
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
||||
printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
||||
}
|
||||
|
||||
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
|
||||
|
@ -279,12 +301,13 @@ static void test_perf() {
|
|||
data.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||
}
|
||||
|
||||
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_softmax (), data, 32);
|
||||
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
||||
BENCH(llama_sampler_init_softmax (), data, 32);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
@ -309,6 +332,14 @@ int main(void) {
|
|||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||
|
||||
printf("XTC should:\n");
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
|
||||
|
||||
printf("XTC should not:\n");
|
||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
|
||||
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue