llama : infill sampling handle very long tokens (#9924)
* llama : infill sampling handle very long tokens ggml-ci * cont : better indices ggml-ci
This commit is contained in:
parent
3752217ed5
commit
99bd4ac28c
4 changed files with 35 additions and 43 deletions
|
@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|||
|
||||
// State carried by the infill sampler: the vocabulary plus two scratch buffers for token text.
struct llama_sampler_infill {
|
||||
// vocabulary passed to the token-to-piece conversion when comparing tokens
const struct llama_vocab * vocab;
|
||||
|
||||
// scratch buffer receiving the text piece of the first token (i0) of a compared pair;
// created with 512 elements by llama_sampler_init_infill_impl
std::vector<char> buf0;
|
||||
// scratch buffer receiving the text piece of the second token (i1) of a compared pair;
// created with 512 elements by llama_sampler_init_infill_impl
std::vector<char> buf1;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
|||
size_t n_combined = 0; GGML_UNUSED(n_combined);
|
||||
|
||||
// combine tokens with common prefix
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
for (size_t j = 0; j < cur_p->size; ++j) {
|
||||
if (cur_p->data[i].logit == -INFINITY) {
|
||||
for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
|
||||
for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
|
||||
if (cur_p->data[i0].logit == -INFINITY) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i == j || cur_p->data[j].logit == -INFINITY) {
|
||||
if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
|
||||
if (cur_p->data[i].p > cur_p->data[j].p) {
|
||||
cur_p->data[i].p += cur_p->data[j].p;
|
||||
cur_p->data[j].logit = -INFINITY;
|
||||
cur_p->data[j].p = 0.0f;
|
||||
} else {
|
||||
cur_p->data[j].p += cur_p->data[i].p;
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
cur_p->data[i].p = 0.0f;
|
||||
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
if (len0 < 0) {
|
||||
ctx->buf0.resize(len0);
|
||||
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
assert(len0 > 0);
|
||||
}
|
||||
|
||||
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
if (len1 < 0) {
|
||||
ctx->buf1.resize(len1);
|
||||
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
assert(len1 > 0);
|
||||
}
|
||||
|
||||
// token i0 is a prefix of token i1
|
||||
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
|
||||
int dst = i0;
|
||||
int src = i1;
|
||||
|
||||
// merge into the token with higher probability
|
||||
if (cur_p->data[i1].p > cur_p->data[i0].p) {
|
||||
std::swap(dst, src);
|
||||
}
|
||||
|
||||
cur_p->data[dst].p += cur_p->data[src].p;
|
||||
cur_p->data[src].logit = -INFINITY;
|
||||
cur_p->data[src].p = 0.0f;
|
||||
|
||||
n_combined++;
|
||||
}
|
||||
}
|
||||
|
@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
|
|||
/* .iface = */ &llama_sampler_infill_i,
|
||||
/* .ctx = */ new llama_sampler_infill {
|
||||
/* .vocab = */ &vocab,
|
||||
/* .buf0 = */ std::vector<char>(512),
|
||||
/* .buf1 = */ std::vector<char>(512),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue