set p for sampled token
This commit is contained in:
parent
22b72c8574
commit
ed7f2d5756
1 changed files with 14 additions and 3 deletions
|
@ -1889,15 +1889,26 @@ struct server_context {
|
||||||
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
|
||||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||||
int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
size_t n_probs = slot.params.sampling.n_probs;
|
||||||
|
|
||||||
// only take at most n_probs tokens
|
bool found_sampled_tok = false;
|
||||||
const int n_probs = slot.params.sampling.n_probs;
|
result.probs.reserve(n_probs);
|
||||||
for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
|
for (int i = 0; i < n_vocab; i++) {
|
||||||
|
// set probability for sampled token
|
||||||
|
if (cur[i].id == result.tok) {
|
||||||
|
found_sampled_tok = true;
|
||||||
|
result.prob = cur[i].p;
|
||||||
|
}
|
||||||
|
// set probability for top n_probs tokens
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
cur[i].id,
|
cur[i].id,
|
||||||
common_detokenize(ctx, {cur[i].id}, special),
|
common_detokenize(ctx, {cur[i].id}, special),
|
||||||
cur[i].p
|
cur[i].p
|
||||||
});
|
});
|
||||||
|
// break if we have all the necessary data
|
||||||
|
if (result.probs.size() == n_probs && found_sampled_tok) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue