From 5b966df17736037890cb481175d853a5418c2d2b Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 19 Dec 2024 14:39:36 +0100
Subject: [PATCH] fix setting prob for sampled token

---
 examples/server/server.cpp | 45 ++++++++++++++++++--------------------
 1 file changed, 21 insertions(+), 24 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9b338e1d9..fa3682a92 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -475,14 +475,14 @@ struct completion_token_output {
                 {"id", p.tok},
                 {"token", txt},
                 {"bytes", str_to_bytes(p.text_to_send)},
-                {
-                    post_sampling_probs ? "top_probs" : "top_logprobs",
-                    p.to_json(post_sampling_probs)
-                },
                 {
                     post_sampling_probs ? "prob" : "logprob",
                     post_sampling_probs ? p.prob : logarithm(p.prob)
                 },
+                {
+                    post_sampling_probs ? "top_probs" : "top_logprobs",
+                    p.to_json(post_sampling_probs)
+                },
             });
         }
         return out;
     }
@@ -1956,52 +1956,49 @@ struct server_context {

     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
         size_t n_probs = slot.params.sampling.n_probs;
-        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
         if (post_sampling) {
             const auto * cur_p = common_sampler_get_candidates(slot.smpl);
             const size_t max_probs = cur_p->size;

-            bool found_sampled_tok = false;
-            result.probs.reserve(max_probs);
+            // set probability for sampled token
             for (size_t i = 0; i < max_probs; i++) {
-                // set probability for sampled token
                 if (cur_p->data[i].id == result.tok) {
-                    found_sampled_tok = true;
                     result.prob = cur_p->data[i].p;
+                    break;
                 }
-                // set probability for top n_probs tokens
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
                 result.probs.push_back({
                     cur_p->data[i].id,
                     common_detokenize(ctx, {cur_p->data[i].id}, special),
                     cur_p->data[i].p
                 });
-                // break if we have all the necessary data
-                if (result.probs.size() == n_probs && found_sampled_tok) {
-                    break;
-                }
             }
         } else {
             // TODO: optimize this with min-p optimization
             std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);

-            bool found_sampled_tok = false;
-            result.probs.reserve(n_probs);
-            for (int i = 0; i < n_vocab; i++) {
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
-                    found_sampled_tok = true;
                     result.prob = cur[i].p;
+                    break;
                 }
-                // set probability for top n_probs tokens
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
                 result.probs.push_back({
                     cur[i].id,
                     common_detokenize(ctx, {cur[i].id}, special),
                     cur[i].p
                 });
-                // break if we have all the necessary data
-                if (result.probs.size() == n_probs && found_sampled_tok) {
-                    break;
-                }
             }
         }
     }
@@ -2894,7 +2891,7 @@ struct server_context {
             completion_token_output result;
             result.tok = id;
             result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
-            result.prob = 1.0f; // set later
+            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

             if (slot.params.sampling.n_probs > 0) {
                 populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
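
For illustration, the shape the patch moves populate_token_probs toward is two independent passes over the candidate list (find the sampled token's probability, then collect the top-N entries) instead of one loop gated by a found_sampled_tok flag. The following is a minimal standalone C++ sketch of that two-pass structure, not llama.cpp code: TokenProb and TokenOutput are simplified stand-ins, and the candidates are assumed to be sorted by descending probability, as the sampler's candidate list is after sampling.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct TokenProb {
        int   id;
        float p;
    };

    struct TokenOutput {
        int                    tok  = -1;   // id of the sampled token
        float                  prob = 0.0f; // probability of the sampled token
        std::vector<TokenProb> probs;       // top-N candidates
    };

    void populate_probs_sketch(const std::vector<TokenProb> & candidates,
                               size_t n_probs, TokenOutput & result) {
        // pass 1: probability of the token that was actually sampled,
        // wherever it sits in the candidate list
        for (const auto & c : candidates) {
            if (c.id == result.tok) {
                result.prob = c.p;
                break;
            }
        }

        // pass 2: top n_probs candidates, independent of pass 1
        const size_t n = std::min(candidates.size(), n_probs);
        result.probs.reserve(n);
        for (size_t i = 0; i < n; i++) {
            result.probs.push_back(candidates[i]);
        }
    }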