diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 5a3f5d889..d50ae2dc7 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1886,25 +1886,17 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
-        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-        const size_t max_probs = cur_p->size;
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
+        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-        // set prob for the sampled token
-        for (size_t i = 0; i < max_probs; ++i) {
-            if (result.tok == cur_p->data[i].id) {
-                result.prob = cur_p->data[i].p;
-                break;
-            }
-        }
-
-        // set probs for the top n tokens
-        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
-            auto tok_id = cur_p->data[i].id;
+        // only take at most n_probs tokens
+        const int n_probs = slot.params.sampling.n_probs;
+        for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
             result.probs.push_back({
-                tok_id,
-                tokens_to_output_formatted_string(ctx, tok_id),
-                cur_p->data[i].p,
+                cur[i].id,
+                common_detokenize(ctx, {cur[i].id}, special),
+                cur[i].p
             });
         }
     }
@@ -2758,7 +2750,9 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                const int tok_idx = slot.i_batch - i;
+
+                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
                 slot.i_batch = -1;
 
@@ -2782,7 +2776,7 @@ struct server_context {
                 result.prob = 1.0f; // set later
 
                 if (slot.params.sampling.n_probs > 0) {
-                    populate_token_probs(slot, result);
+                    populate_token_probs(slot, result, params_base.special, tok_idx);
                 }
 
                 if (!process_token(result, slot)) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 3750cf758..60c7656d3 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
 static std::string safe_json_to_str(json data) {
     return data.dump(-1, ' ', false, json::error_handler_t::replace);
 }
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // sort tokens by logit (descending) so that cur[0] holds the maximum
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // apply softmax, subtracting the max logit for numerical stability
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    return cur;
+}
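
For context, the new `get_token_probabilities()` helper computes a full softmax over the logits at batch index `idx` and returns the candidates sorted by probability, so `populate_token_probs()` only needs to take the first `n_probs` entries. The standalone sketch below reproduces that computation on toy data to illustrate the max-subtraction and normalization steps; it uses no llama.cpp types (`TokenData` is a hypothetical stand-in for `llama_token_data`) and the toy logits are made up.

```cpp
// Standalone illustration of the softmax-over-logits computation performed by
// get_token_probabilities(); TokenData is a hypothetical stand-in for llama_token_data.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct TokenData {
    int   id;
    float logit;
    float p;
};

int main() {
    // toy logits for a 4-token "vocabulary"
    const std::vector<float> logits = { 1.0f, 3.5f, 2.0f, -1.0f };

    std::vector<TokenData> cur(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        cur[i] = { (int) i, logits[i], 0.0f };
    }

    // sort by logit (descending) so cur[0] holds the maximum
    std::sort(cur.begin(), cur.end(), [](const TokenData & a, const TokenData & b) {
        return a.logit > b.logit;
    });

    // softmax: subtract the max logit before exponentiating for numerical stability
    const float max_l = cur[0].logit;
    float cum_sum = 0.0f;
    for (auto & td : cur) {
        td.p = std::exp(td.logit - max_l);
        cum_sum += td.p;
    }
    for (auto & td : cur) {
        td.p /= cum_sum;
    }

    // take the top n_probs entries, as populate_token_probs() does
    const int n_probs = 2;
    for (int i = 0; i < n_probs && i < (int) cur.size(); ++i) {
        std::printf("token %d: p = %.4f\n", cur[i].id, cur[i].p);
    }
    return 0;
}
```

Note that the helper sorts the entire vocabulary even though only `n_probs` entries are ultimately used; a partial sort over the top candidates would be a possible optimization, but the full sort keeps the helper simple.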