diff --git a/examples/server/README.md b/examples/server/README.md
index 647fa49ab..6d6465692 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -496,8 +496,8 @@ These words will not be included in the completion, so make sure to add them to
   },
 ```
 Please note that if `post_sampling_probs` is set to `true`:
-  - `logprob` will be replace with `prob`, with the value between 0.0 and 1.0
-  - `top_logprobs` will be replace with `top_probs`. Each element inside contains:
+  - `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
+  - `top_logprobs` will be replaced with `top_probs`. Each element contains:
     - `id`: token ID
     - `token`: token in string
     - `bytes`: token in bytes
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a5ac8db76..dfa38db06 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -443,7 +443,7 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
     std::vector<token_prob> probs;
@@ -451,12 +451,12 @@
     json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            std::string tok_str(p.tok_str);
-            tok_str.resize(validate_utf8(tok_str));
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
                 {"id", p.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(p.tok_str)},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.txt)},
                 {
                     post_sampling_probs ? "prob" : "logprob",
                     post_sampling_probs ? p.prob : logarithm(p.prob)
@@ -468,20 +468,20 @@ struct completion_token_output {
     static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & it : probs) {
-            std::string tok_str(it.text_to_send);
-            tok_str.resize(validate_utf8(tok_str));
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"id", it.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(it.text_to_send)},
+                {"id", p.tok},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.text_to_send)},
                 {
                     post_sampling_probs ? "top_probs" : "top_logprobs",
-                    it.to_json(post_sampling_probs)
+                    p.to_json(post_sampling_probs)
                 },
                 {
                     post_sampling_probs ? "prob" : "logprob",
-                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
                 },
             });
         }
         return out;
@@ -1958,28 +1958,7 @@ struct server_context {
         size_t n_probs = slot.params.sampling.n_probs;
         int n_vocab = llama_n_vocab(llama_get_model(ctx));
         if (post_sampling) {
-            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-
-            bool found_sampled_tok = false;
-            result.probs.reserve(n_probs);
-            for (int i = 0; i < n_vocab; i++) {
-                // set probability for sampled token
-                if (cur[i].id == result.tok) {
-                    found_sampled_tok = true;
-                    result.prob = cur[i].p;
-                }
-                // set probability for top n_probs tokens
-                result.probs.push_back({
-                    cur[i].id,
-                    common_detokenize(ctx, {cur[i].id}, special),
-                    cur[i].p
-                });
-                // break if we have all the necessary data
-                if (result.probs.size() == n_probs && found_sampled_tok) {
-                    break;
-                }
-            }
-        } else {
+            // TODO: optimize this with min-p optimization
             const auto * cur_p = common_sampler_get_candidates(slot.smpl);
             const size_t max_probs = cur_p->size;
 
@@ -2002,6 +1981,28 @@ struct server_context {
                     break;
                 }
             }
+        } else {
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
+            }
         }
     }
 
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 24342b3bb..b88d45f18 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -325,7 +325,7 @@ def test_n_probs_post_sampling():
     for tok in res.body["completion_probabilities"]:
         assert "id" in tok and tok["id"] > 0
         assert "token" in tok and type(tok["token"]) == str
-        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
         assert "bytes" in tok and type(tok["bytes"]) == list
         assert len(tok["top_probs"]) == 10
         for prob in tok["top_probs"]:
@@ -333,3 +333,5 @@ def test_n_probs_post_sampling():
         assert "token" in prob and type(prob["token"]) == str
         assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
         assert "bytes" in prob and type(prob["bytes"]) == list
+        # because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
+        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
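
For quick manual verification of the behavior this patch touches, here is a minimal client sketch (not part of the patch) that requests post-sampling probabilities from a running `llama-server` and prints the per-token `prob` and `top_probs` fields documented in the README excerpt above. The host, port, prompt, and sampling parameter values are assumptions for illustration only.

```python
# Minimal sketch, assuming llama-server is listening on localhost:8080.
# Endpoint and field names follow the README/tests in this patch; the
# prompt and parameter values below are illustrative placeholders.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "Hello",            # hypothetical prompt
    "n_predict": 4,
    "n_probs": 10,                # return top-10 candidates per generated token
    "post_sampling_probs": True,  # report `prob`/`top_probs` instead of logprobs
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # with post_sampling_probs=true, the sampled token carries `prob` in (0.0, 1.0]
    print(f'{tok["id"]:6d} {tok["token"]!r} prob={tok["prob"]:.4f}')
    for cand in tok["top_probs"]:
        print(f'       {cand["id"]:6d} {cand["token"]!r} prob={cand["prob"]:.4f}')
```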