llama : add infill sampler (#9896)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-15 16:35:33 +03:00 committed by GitHub
parent 223c25a72f
commit 755a9b2bf0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 300 additions and 29 deletions

View file

@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
};
}
// infill
//#define GGML_DEBUG_SAMPLER_INFILL
struct llama_sampler_infill {
const struct llama_vocab * vocab;
};
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
return "infill";
}
static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;
llama_sampler_softmax_impl(cur_p);
#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
#else
#define LOG_DBG_CUR(...)
#endif
for (size_t i = 0; i < cur_p->size; ++i) {
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
float p_txt_sum = 0.0f;
float p_eog_sum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
} else {
p_txt_sum += cur_p->data[i].p;
}
}
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
// keep just the EOG tokens
const auto size_org = cur_p->size;
cur_p->size = 0;
float p_sum = 0.0f;
for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i];
}
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
}
return;
}
size_t n_combined = 0; GGML_UNUSED(n_combined);
// combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) {
for (size_t j = 0; j < cur_p->size; ++j) {
if (cur_p->data[i].logit == -INFINITY) {
break;
}
if (i == j || cur_p->data[j].logit == -INFINITY) {
continue;
}
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
cur_p->data[j].p = 0.0f;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
}
n_combined++;
}
}
}
size_t n_non_eog = 0;
size_t size_org = cur_p->size;
float p_sum = 0.0f;
float thold = 0.2f;
cur_p->size = 0;
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) {
continue;
}
if (!is_eog) {
++n_non_eog;
}
p_sum += cur_p->data[i].p;
// keep this token
cur_p->data[cur_p->size++] = cur_p->data[i];
}
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;
return;
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
size_org = cur_p->size;
p_sum = 0.0f;
thold = 1.0/(n_non_eog + 1);
cur_p->size = 0;
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) {
continue;
}
p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i];
}
// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
#undef LOG_DBG_CUR
}
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab);
}
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
delete (llama_sampler_infill *) smpl->ctx;
}
static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
};
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
},
};
}
// utils
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {