First round of fixes based on review comments.
Sorting behavior still needs to be investigated.
This commit is contained in:
parent
db54ac5df4
commit
41e16654bd
3 changed files with 16 additions and 10 deletions
|
@ -975,7 +975,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--xtc-t"}, "N",
|
{"--xtc-t"}, "N",
|
||||||
format("xtc threshold (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_t),
|
format("xtc threshold (default: %.1f, 0.0 or 1.0 = disabled)", (double)params.sparams.xtc_t),
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.sparams.xtc_t = std::stof(value);
|
params.sparams.xtc_t = std::stof(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -110,7 +110,7 @@ struct gpt_sampler_params {
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
float min_p = 0.05f; // 0.0 = disabled
|
float min_p = 0.05f; // 0.0 = disabled
|
||||||
float xtc_p = 0.50f; // 0.0 = disabled
|
float xtc_p = 0.50f; // 0.0 = disabled
|
||||||
float xtc_t = 0.10f; // 1.0 = disabled
|
float xtc_t = 0.10f; // 0.0 or 1.0 = disabled
|
||||||
float xtc_t_max = 1.00f; // 0.0 = disabled
|
float xtc_t_max = 1.00f; // 0.0 = disabled
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||||
|
|
|
@ -1075,37 +1075,43 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
|
||||||
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx->probability <= 0.0f || ctx->threshold <= 0.0f || cur_p->size <= 1 || ctx->min_keep <= 2) {
|
if (ctx->probability <= 0.0f
|
||||||
|
|| ctx->threshold <= 0.0f
|
||||||
|
|| ctx->threshold >= 1.0f
|
||||||
|
|| ctx->threshold_max <= 0.0f
|
||||||
|
|| ctx->threshold_max <= ctx->threshold
|
||||||
|
|| cur_p->size <= 2
|
||||||
|
|| ctx->min_keep <= 2) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
float chance = (float)(rd()%100)/100;
|
float chance = (float)(rd()%100 - 1)/100;
|
||||||
if (chance > ctx->probability) return;
|
if (chance > ctx->probability) return;
|
||||||
|
|
||||||
// in case it's not sorted/recalculated yet
|
// in case it's not sorted/recalculated yet
|
||||||
llama_sampler_softmax_impl(cur_p);
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
int removed = 0;
|
int found = 0;
|
||||||
// going through all candidates from back to front, easier to keep the last of probables
|
// going through all candidates from back to front, easier to keep the last of probables
|
||||||
for (int i = (cur_p->size - 1); i >= 0; --i) {
|
for (int i = (cur_p->size - 1); i >= 0; --i) {
|
||||||
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
|
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
|
||||||
++removed;
|
++found;
|
||||||
if (removed > 1) {
|
if (found > 1) {
|
||||||
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
|
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
|
||||||
cur_p->data[i].logit = -999.0f;
|
cur_p->data[i].logit = -999.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (removed > 1) {
|
if (found > 1) {
|
||||||
// sorting with new logits, ex-last probable will be the first anyway
|
// sorting with new logits, ex-last probable will be the first anyway
|
||||||
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
return a.logit > b.logit;
|
return a.logit > b.logit;
|
||||||
});
|
});
|
||||||
cur_p->sorted = true;
|
|
||||||
|
|
||||||
// resizing now that penalized tokens are at the back
|
// resizing now that penalized tokens are at the back
|
||||||
cur_p->size = cur_p->size - removed + 1;
|
cur_p->size = cur_p->size - found + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue