fix setting prob for sampled token
This commit is contained in:
parent
a217382b25
commit
5b966df177
1 changed file with 21 additions and 24 deletions
|
@ -475,14 +475,14 @@ struct completion_token_output {
|
||||||
{"id", p.tok},
|
{"id", p.tok},
|
||||||
{"token", txt},
|
{"token", txt},
|
||||||
{"bytes", str_to_bytes(p.text_to_send)},
|
{"bytes", str_to_bytes(p.text_to_send)},
|
||||||
{
|
|
||||||
post_sampling_probs ? "top_probs" : "top_logprobs",
|
|
||||||
p.to_json(post_sampling_probs)
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
post_sampling_probs ? "prob" : "logprob",
|
post_sampling_probs ? "prob" : "logprob",
|
||||||
post_sampling_probs ? p.prob : logarithm(p.prob)
|
post_sampling_probs ? p.prob : logarithm(p.prob)
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
post_sampling_probs ? "top_probs" : "top_logprobs",
|
||||||
|
p.to_json(post_sampling_probs)
|
||||||
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
|
@ -1956,52 +1956,49 @@ struct server_context {
|
||||||
|
|
||||||
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
||||||
size_t n_probs = slot.params.sampling.n_probs;
|
size_t n_probs = slot.params.sampling.n_probs;
|
||||||
int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
if (post_sampling) {
|
if (post_sampling) {
|
||||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
||||||
const size_t max_probs = cur_p->size;
|
const size_t max_probs = cur_p->size;
|
||||||
|
|
||||||
bool found_sampled_tok = false;
|
// set probability for sampled token
|
||||||
result.probs.reserve(max_probs);
|
|
||||||
for (size_t i = 0; i < max_probs; i++) {
|
for (size_t i = 0; i < max_probs; i++) {
|
||||||
// set probability for sampled token
|
|
||||||
if (cur_p->data[i].id == result.tok) {
|
if (cur_p->data[i].id == result.tok) {
|
||||||
found_sampled_tok = true;
|
|
||||||
result.prob = cur_p->data[i].p;
|
result.prob = cur_p->data[i].p;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
// set probability for top n_probs tokens
|
}
|
||||||
|
|
||||||
|
// set probability for top n_probs tokens
|
||||||
|
result.probs.reserve(max_probs);
|
||||||
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
cur_p->data[i].id,
|
cur_p->data[i].id,
|
||||||
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
||||||
cur_p->data[i].p
|
cur_p->data[i].p
|
||||||
});
|
});
|
||||||
// break if we have all the necessary data
|
|
||||||
if (result.probs.size() == n_probs && found_sampled_tok) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// TODO: optimize this with min-p optimization
|
// TODO: optimize this with min-p optimization
|
||||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||||
|
|
||||||
bool found_sampled_tok = false;
|
// set probability for sampled token
|
||||||
result.probs.reserve(n_probs);
|
for (size_t i = 0; i < n_vocab; i++) {
|
||||||
for (int i = 0; i < n_vocab; i++) {
|
|
||||||
// set probability for sampled token
|
// set probability for sampled token
|
||||||
if (cur[i].id == result.tok) {
|
if (cur[i].id == result.tok) {
|
||||||
found_sampled_tok = true;
|
|
||||||
result.prob = cur[i].p;
|
result.prob = cur[i].p;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
// set probability for top n_probs tokens
|
}
|
||||||
|
|
||||||
|
// set probability for top n_probs tokens
|
||||||
|
result.probs.reserve(n_probs);
|
||||||
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
cur[i].id,
|
cur[i].id,
|
||||||
common_detokenize(ctx, {cur[i].id}, special),
|
common_detokenize(ctx, {cur[i].id}, special),
|
||||||
cur[i].p
|
cur[i].p
|
||||||
});
|
});
|
||||||
// break if we have all the necessary data
|
|
||||||
if (result.probs.size() == n_probs && found_sampled_tok) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2894,7 +2891,7 @@ struct server_context {
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||||
result.prob = 1.0f; // set later
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||||
|
|
||||||
if (slot.params.sampling.n_probs > 0) {
|
if (slot.params.sampling.n_probs > 0) {
|
||||||
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue