From ed7f2d5756933edff163691b2a455e8847ac7651 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 12 Dec 2024 14:05:38 +0100
Subject: [PATCH] set p for sampled token

---
 examples/server/server.cpp | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1a296aa00..95bd531b3 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1889,15 +1889,26 @@ struct server_context {
 
     void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
         std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
         int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        size_t n_probs = slot.params.sampling.n_probs;
 
-        // only take at most n_probs tokens
-        const int n_probs = slot.params.sampling.n_probs;
-        for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
+        bool found_sampled_tok = false;
+        result.probs.reserve(n_probs);
+        for (int i = 0; i < n_vocab; i++) {
+            // set probability for sampled token
+            if (cur[i].id == result.tok) {
+                found_sampled_tok = true;
+                result.prob = cur[i].p;
+            }
+            // set probability for top n_probs tokens
             result.probs.push_back({
                 cur[i].id,
                 common_detokenize(ctx, {cur[i].id}, special),
                 cur[i].p
             });
+            // break if we have all the necessary data
+            if (result.probs.size() == n_probs && found_sampled_tok) {
+                break;
+            }
         }
     }
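
For illustration, here is a minimal standalone sketch of the one-pass scan this
patch introduces: walk a probability-sorted candidate list, collect the top
n_probs entries, and record the probability of the sampled token even when it
ranks below the top-n_probs cutoff. All type and function names below
(Candidate, TokenOutput, and this simplified populate_token_probs) are
hypothetical stand-ins for the llama.cpp server types, not the upstream code.

    // Sketch only: Candidate stands in for llama_token_data,
    // TokenOutput for completion_token_output.
    #include <cstdio>
    #include <vector>

    struct Candidate {
        int   id;
        float p;
    };

    struct TokenOutput {
        int                    tok  = -1;   // sampled token id
        float                  prob = 0.f;  // probability of the sampled token
        std::vector<Candidate> probs;       // top-n_probs candidates
    };

    // cur must be sorted by p in descending order, as the server code's
    // get_token_probabilities guarantees.
    void populate_token_probs(const std::vector<Candidate> & cur, size_t n_probs, TokenOutput & result) {
        bool found_sampled_tok = false;
        result.probs.reserve(n_probs);
        for (size_t i = 0; i < cur.size(); i++) {
            // record the probability of the sampled token wherever it ranks
            if (cur[i].id == result.tok) {
                found_sampled_tok = true;
                result.prob = cur[i].p;
            }
            // keep only the top n_probs candidates
            if (result.probs.size() < n_probs) {
                result.probs.push_back(cur[i]);
            }
            // stop once both pieces of data are complete
            if (result.probs.size() >= n_probs && found_sampled_tok) {
                break;
            }
        }
    }

    int main() {
        // three candidates, sorted by probability; the sampled token (id 7)
        // ranks below the top-2 cutoff
        std::vector<Candidate> cur = {{3, 0.6f}, {5, 0.3f}, {7, 0.1f}};
        TokenOutput result;
        result.tok = 7;
        populate_token_probs(cur, /*n_probs=*/2, result);
        printf("sampled p = %.2f, top-n size = %zu\n", result.prob, result.probs.size());
        // expected: sampled p = 0.10, top-n size = 2
    }

One defensive variation in the sketch: the push is guarded by
result.probs.size() < n_probs and the break test uses >=, so the top-n list
never grows past n_probs even when the sampled token sits below the cutoff.
This is a hedged design choice for the standalone example, not a claim about
the patched server code, whose unconditional push and == test behave the same
whenever the sampled token is found before the list fills up.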