First fixes by comments

Still need to look into sorting
This commit is contained in:
MaggotHATE 2024-10-04 21:34:31 +05:00 committed by GitHub
parent db54ac5df4
commit 41e16654bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 10 deletions

View file

@ -975,7 +975,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
).set_sparam());
add_opt(llama_arg(
{"--xtc-t"}, "N",
format("xtc threshold (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_t),
format("xtc threshold (default: %.1f, 0.0 or 1.0 = disabled)", (double)params.sparams.xtc_t),
[](gpt_params & params, const std::string & value) {
params.sparams.xtc_t = std::stof(value);
}

View file

@ -110,7 +110,7 @@ struct gpt_sampler_params {
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_p = 0.50f; // 0.0 = disabled
float xtc_t = 0.10f; // 1.0 = disabled
float xtc_t = 0.10f; // 0.0 or 1.0 = disabled
float xtc_t_max = 1.00f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled

View file

@ -1075,37 +1075,43 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
if (ctx->probability <= 0.0f || ctx->threshold <= 0.0f || cur_p->size <= 1 || ctx->min_keep <= 2) {
if (ctx->probability <= 0.0f
|| ctx->threshold <= 0.0f
|| ctx->threshold >= 1.0f
|| ctx->threshold_max <= 0.0f
|| ctx->threshold_max <= ctx->threshold
|| cur_p->size <= 2
|| ctx->min_keep <= 2) {
return;
}
std::random_device rd;
float chance = (float)(rd()%100)/100;
float chance = (float)(rd()%100 - 1)/100;
if (chance > ctx->probability) return;
// in case it's not sorted/recalculated yet
llama_sampler_softmax_impl(cur_p);
int removed = 0;
int found = 0;
// going through all candidates from back to front, easier to keep the last of probables
for (int i = (cur_p->size - 1); i >= 0; --i) {
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
++removed;
if (removed > 1) {
++found;
if (found > 1) {
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
cur_p->data[i].logit = -999.0f;
}
}
}
if (removed > 1) {
if (found > 1) {
// sorting with new logits, ex-last probable will be the first anyway
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
cur_p->sorted = true;
// resizing now that penalized tokens are at the back
cur_p->size = cur_p->size - removed + 1;
cur_p->size = cur_p->size - found + 1;
}
}