Removed xtc_threshold_max
This commit is contained in:
parent
acada1a5e7
commit
9c43a01c5d
11 changed files with 21 additions and 55 deletions
|
@ -987,13 +987,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
params.sparams.xtc_threshold = std::stof(value);
|
params.sparams.xtc_threshold = std::stof(value);
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
|
||||||
{"-xtc-t-max", "--xtc-threshold-max"}, "N",
|
|
||||||
format("xtc upper threshold (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_threshold_max),
|
|
||||||
[](common_params & params, const std::string & value) {
|
|
||||||
params.sparams.xtc_threshold_max = std::stof(value);
|
|
||||||
}
|
|
||||||
).set_sparam());
|
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--typical"}, "N",
|
{"--typical"}, "N",
|
||||||
format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
|
format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
|
||||||
|
|
|
@ -2090,7 +2090,6 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
|
||||||
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
|
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
|
||||||
fprintf(stream, "xtc_probability: %f # default: 0.5\n", sparams.xtc_probability);
|
fprintf(stream, "xtc_probability: %f # default: 0.5\n", sparams.xtc_probability);
|
||||||
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
|
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
|
||||||
fprintf(stream, "xtc_threshold_max: %f # default: 1.0\n", sparams.xtc_threshold_max);
|
|
||||||
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
|
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
|
||||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||||
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
||||||
|
|
|
@ -112,7 +112,6 @@ struct common_sampler_params {
|
||||||
float min_p = 0.05f; // 0.0 = disabled
|
float min_p = 0.05f; // 0.0 = disabled
|
||||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||||
float xtc_threshold = 0.10f; // 0.5 = disabled
|
float xtc_threshold = 0.10f; // 0.5 = disabled
|
||||||
float xtc_threshold_max = 1.00f; // 0.0 = disabled
|
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||||
|
|
|
@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
|
||||||
|
|
||||||
snprintf(result, sizeof(result),
|
snprintf(result, sizeof(result),
|
||||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, xtc_threshold_max = %.3f, typical_p = %.3f, temp = %.3f\n"
|
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, xtc_threshold_max, typ_p, temp,
|
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||||
mirostat, mirostat_eta, mirostat_tau);
|
mirostat, mirostat_eta, mirostat_tau);
|
||||||
|
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
|
@ -185,7 +185,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_XTC:
|
case COMMON_SAMPLER_TYPE_XTC:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.xtc_threshold_max, params.min_keep, params.seed));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TFS_Z:
|
case COMMON_SAMPLER_TYPE_TFS_Z:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||||
|
|
|
@ -245,17 +245,14 @@ Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
|
||||||
|
|
||||||
- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
|
- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
|
||||||
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
|
- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
|
||||||
- `--xtc-threshold-max N`: Sets a maximum probability threshold for tokens to be removed (highly experimental) (default: 1.0).
|
|
||||||
|
|
||||||
Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-p` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
|
Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-p` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one.
|
||||||
|
|
||||||
By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repetition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
|
By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repetition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models.
|
||||||
|
|
||||||
The additional `xtc-threshold-max` parameter may help with finetuned models that already give relatively creative output, meaning that clichés and repetitive phrases may appear at lower probabilities. It allows to remove tokens from a middle range which will always be specific to a model, requiring careful experimenting. Leave `xtc-threshold-max` on default 1.0 for all base/instruct models.
|
|
||||||
|
|
||||||
Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 -xtc-p 0.5`.
|
Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 -xtc-p 0.5`.
|
||||||
|
|
||||||
Example usage: `-xtc-p 0.5 -xtc-t 0.1 -xtc-t-max 1.0`
|
Example usage: `-xtc-p 0.5 -xtc-t 0.1`
|
||||||
|
|
||||||
### Logit Bias
|
### Logit Bias
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,6 @@
|
||||||
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
|
||||||
xtc_probability: 0.0, // 0 = disabled;
|
xtc_probability: 0.0, // 0 = disabled;
|
||||||
xtc_threshold: 0.1, // 0.5 = disabled;
|
xtc_threshold: 0.1, // 0.5 = disabled;
|
||||||
xtc_threshold_max: 1.0, // 0 = disabled;
|
|
||||||
tfs_z: 1.0, // 1.0 = disabled
|
tfs_z: 1.0, // 1.0 = disabled
|
||||||
typical_p: 1.0, // 1.0 = disabled
|
typical_p: 1.0, // 1.0 = disabled
|
||||||
presence_penalty: 0.0, // 0.0 = disabled
|
presence_penalty: 0.0, // 0.0 = disabled
|
||||||
|
@ -841,7 +840,6 @@ return html`
|
||||||
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
|
||||||
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||||
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||||
${FloatField({ label: "XTC max threshold", title: "Sets a maximum probability threshold for tokens to be removed (highly experimental)", max: 1.0, min: 0.0, name: "xtc_threshold_max", step: 0.01, value: params.value.xtc_threshold_max })}
|
|
||||||
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
|
|
||||||
|
@ -1140,7 +1138,6 @@ document.addEventListener('DOMContentLoaded', (event) => {
|
||||||
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
|
min_p: { snapValue: 0.05, snapRangeMultiplier: 2 },
|
||||||
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
|
||||||
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
|
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
|
||||||
xtc_threshold_max: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
|
||||||
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
|
||||||
|
|
|
@ -309,7 +309,6 @@
|
||||||
min_p: 0.05, // 0 = disabled
|
min_p: 0.05, // 0 = disabled
|
||||||
xtc_probability: 0.0, // 0 = disabled;
|
xtc_probability: 0.0, // 0 = disabled;
|
||||||
xtc_threshold: 0.1, // 0.5 = disabled;
|
xtc_threshold: 0.1, // 0.5 = disabled;
|
||||||
xtc_threshold_max: 1.0, // 0 = disabled;
|
|
||||||
tfs_z: 1.0, // 1.0 = disabled
|
tfs_z: 1.0, // 1.0 = disabled
|
||||||
typical_p: 1.0, // 1.0 = disabled
|
typical_p: 1.0, // 1.0 = disabled
|
||||||
presence_penalty: 0.0, // 0.0 = disabled
|
presence_penalty: 0.0, // 0.0 = disabled
|
||||||
|
@ -1018,7 +1017,6 @@
|
||||||
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
|
||||||
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
|
||||||
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
|
||||||
${FloatField({ label: "XTC upper threshold", max: 1.0, min: 0.0, name: "xtc_threshold_max", step: 0.01, value: params.value.xtc_threshold_max })}
|
|
||||||
</fieldset>
|
</fieldset>
|
||||||
<hr />
|
<hr />
|
||||||
<fieldset class="three">
|
<fieldset class="three">
|
||||||
|
|
|
@ -893,7 +893,6 @@ struct server_context {
|
||||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||||
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
|
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
|
||||||
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
|
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
|
||||||
slot.sparams.xtc_threshold_max = json_value(data, "xtc_threshold_max", default_sparams.xtc_threshold_max);
|
|
||||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||||
|
@ -1244,7 +1243,6 @@ struct server_context {
|
||||||
{"min_p", slot.sparams.min_p},
|
{"min_p", slot.sparams.min_p},
|
||||||
{"xtc_probability", slot.sparams.xtc_probability},
|
{"xtc_probability", slot.sparams.xtc_probability},
|
||||||
{"xtc_threshold", slot.sparams.xtc_threshold},
|
{"xtc_threshold", slot.sparams.xtc_threshold},
|
||||||
{"xtc_threshold_max", slot.sparams.xtc_threshold_max},
|
|
||||||
{"tfs_z", slot.sparams.tfs_z},
|
{"tfs_z", slot.sparams.tfs_z},
|
||||||
{"typical_p", slot.sparams.typ_p},
|
{"typical_p", slot.sparams.typ_p},
|
||||||
{"repeat_last_n", slot.sparams.penalty_last_n},
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
||||||
|
|
|
@ -1095,7 +1095,7 @@ extern "C" {
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
||||||
|
|
||||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, float t_max, size_t min_keep, uint32_t seed);
|
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
|
|
|
@ -1064,7 +1064,6 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
struct llama_sampler_xtc {
|
struct llama_sampler_xtc {
|
||||||
const float probability;
|
const float probability;
|
||||||
const float threshold;
|
const float threshold;
|
||||||
const float threshold_max;
|
|
||||||
const size_t min_keep;
|
const size_t min_keep;
|
||||||
|
|
||||||
const uint32_t seed;
|
const uint32_t seed;
|
||||||
|
@ -1082,8 +1081,6 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
|
|
||||||
if (ctx->probability <= 0.0f
|
if (ctx->probability <= 0.0f
|
||||||
|| ctx->threshold > 0.5f
|
|| ctx->threshold > 0.5f
|
||||||
|| ctx->threshold_max <= 0.0f
|
|
||||||
|| ctx->threshold_max <= ctx->threshold
|
|
||||||
|| cur_p->size <= 2) {
|
|| cur_p->size <= 2) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1095,35 +1092,29 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
|
||||||
// in case it's not sorted/recalculated yet
|
// in case it's not sorted/recalculated yet
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
int pos_first = -1;
|
|
||||||
int pos_last = 0;
|
int pos_last = 0;
|
||||||
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
|
if (cur_p->data[i].p - ctx->threshold >= -1e-5) {
|
||||||
if (cur_p->data[i].p - ctx->threshold_max > 1e-3) pos_first = i;
|
|
||||||
pos_last = i;
|
pos_last = i;
|
||||||
} else {
|
} else break;
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int to_remove = pos_last - (1 + pos_first);
|
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
|
||||||
|
|
||||||
if (cur_p->size - to_remove >= ctx->min_keep && to_remove > 0) {
|
size_t last_idx = cur_p->size - pos_last;
|
||||||
|
|
||||||
size_t last_idx = cur_p->size - to_remove;
|
for (size_t i = 0; i <= last_idx; ++i) {
|
||||||
|
cur_p->data[i] = cur_p->data[i + pos_last];
|
||||||
for (size_t i = pos_first + 1; i <= last_idx; ++i) {
|
|
||||||
cur_p->data[i] = cur_p->data[i + to_remove];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cur_p->size = cur_p->size - to_remove;
|
cur_p->size = cur_p->size - pos_last;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
|
||||||
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->threshold_max, ctx->min_keep, ctx->seed);
|
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
|
||||||
|
|
||||||
// copy the state
|
// copy the state
|
||||||
{
|
{
|
||||||
|
@ -1154,14 +1145,13 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
/* .free = */ llama_sampler_xtc_free,
|
/* .free = */ llama_sampler_xtc_free,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, float t_max, size_t min_keep, uint32_t seed) {
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
||||||
auto seed_cur = get_rng_seed(seed);
|
auto seed_cur = get_rng_seed(seed);
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_xtc_i,
|
/* .iface = */ &llama_sampler_xtc_i,
|
||||||
/* .ctx = */ new llama_sampler_xtc {
|
/* .ctx = */ new llama_sampler_xtc {
|
||||||
/* .probability = */ p,
|
/* .probability = */ p,
|
||||||
/* .threshold = */ t,
|
/* .threshold = */ t,
|
||||||
/* .threshold_max = */ t_max,
|
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
/* .seed = */ seed,
|
/* .seed = */ seed,
|
||||||
/* .seed_cur = */ seed_cur,
|
/* .seed_cur = */ seed_cur,
|
||||||
|
|
|
@ -111,7 +111,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t, float t_max) {
|
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
@ -124,7 +124,7 @@ static void test_xtc(const std::vector<float> & probs, const std::vector<float>
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_softmax(), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_xtc(p, t, t_max, 0, 0), &cur_p);
|
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||||
|
@ -306,7 +306,7 @@ static void test_perf() {
|
||||||
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 0.8f, 1, 1), data, 32);
|
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_softmax (), data, 32);
|
BENCH(llama_sampler_init_softmax (), data, 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,17 +333,12 @@ int main(void) {
|
||||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||||
|
|
||||||
printf("XTC should:\n");
|
printf("XTC should:\n");
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f, 1.00f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.10f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.1f}, 0.99f, 0.10f, 0.35f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.20f, 1.00f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 1.00f);
|
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.1f}, 0.99f, 0.10f, 0.25f);
|
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.35f);
|
|
||||||
printf("XTC should not:\n");
|
printf("XTC should not:\n");
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.10f, 0.15f);
|
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f);
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.20f, 0.25f);
|
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.30f, 0.35f);
|
|
||||||
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.40f, 1.00f);
|
|
||||||
|
|
||||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue