From ecadd37c63381766875adacb687b3856f27fa913 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 18 Dec 2024 14:11:04 +0100
Subject: [PATCH] add `post_sampling_probs` option

---
 examples/server/README.md                     |  84 +++++++------
 examples/server/server.cpp                    | 114 ++++++++++++------
 examples/server/tests/unit/test_completion.py |  27 +++++
 examples/server/tests/unit/test_embedding.py  |   3 +
 4 files changed, 150 insertions(+), 78 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index fa6df1ce4..e43845135 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -449,52 +449,56 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
+`post_sampling_probs`: Returns the probabilities of the top `n_probs` tokens after applying the sampling chain.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
-
-```json
-{
-  "content": "<the generated completion text>",
-  "tokens": [ generated token ids if requested ],
-  ...
-  "probs": [
-    {
-      "id": <token id>,
-      "logprob": float,
-      "token": "<token text>",
-      "bytes": [int, int, ...],
-      "top_logprobs": [
-        {
-          "id": <token id>,
-          "logprob": float,
-          "token": "<token text>",
-          "bytes": [int, int, ...],
-        },
-        {
-          "id": <token id>,
-          "logprob": float,
-          "token": "<token text>",
-          "bytes": [int, int, ...],
-        },
-        ...
-      ]
-    },
-    {
-      "id": <token id>,
-      "logprob": float,
-      "token": "<token text>",
-      "bytes": [int, int, ...],
-      "top_logprobs": [
-        ...
-      ]
-    },
-    ...
-  ]
-},
-```
+  ```json
+  {
+    "content": "<the generated completion text>",
+    "tokens": [ generated token ids if requested ],
+    ...
+    "probs": [
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<token text>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          ...
+        ]
+      },
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<token text>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          ...
+        ]
+      },
+      ...
+    ]
+  },
+  ```
+  Please note that if `post_sampling_probs` is set to `true`:
+  - `logprob` will be replaced with `prob`, with a value between 0.0 and 1.0
+  - the returned number of probabilities may be less than `n_probs`
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
 - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
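As a quick illustration of the option documented in the README hunk above, a client call might look like the sketch below. It is not part of the patch itself: it assumes a `llama-server` instance already listening on `http://localhost:8080`, and the prompt and parameter values are arbitrary examples.

```python
# Sketch: request post-sampling probabilities from a locally running llama-server.
# Assumes http://localhost:8080; prompt and parameter values are illustrative only.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 5,
    "n_probs": 3,
    "post_sampling_probs": True,  # probabilities after the sampling chain
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # with post_sampling_probs=true each entry carries "prob" (0.0-1.0) instead of "logprob"
    top = [(t["token"], t["prob"]) for t in tok["top_logprobs"]]
    print(tok["token"], tok["prob"], top)
```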
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 854dbda1c..93196adcd 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -93,6 +93,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;
 
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     bool ignore_eos = false;
 
     struct common_params_sampling sampling;
@@ -151,6 +152,7 @@ struct slot_params {
             {"speculative.n_min",   speculative.n_min},
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
+            {"post_sampling_probs", post_sampling_probs},
         };
     }
 };
@@ -231,6 +233,7 @@ struct server_task {
         params.sampling.seed       = json_value(data, "seed",     defaults.sampling.seed);
         params.sampling.n_probs    = json_value(data, "n_probs",  defaults.sampling.n_probs);
         params.sampling.min_keep   = json_value(data, "min_keep", defaults.sampling.min_keep);
+        params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
         params.speculative.n_min   = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max   = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@@ -449,7 +452,7 @@ struct completion_token_output {
     };
     std::vector<prob_info> probs;
 
-    json to_json() const {
+    json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
             std::string tok_str(p.tok_str);
@@ -458,13 +461,16 @@ struct completion_token_output {
                 {"id",      p.tok},
                 {"token",   tok_str},
                 {"bytes",   str_to_bytes(p.tok_str)},
-                {"logprob", logarithm(p.prob)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
             });
         }
         return probs_for_token;
     }
 
-    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
         for (const auto & it : probs) {
             std::string tok_str(it.text_to_send);
@@ -472,9 +478,12 @@ struct completion_token_output {
             out.push_back(json {
                 {"id",           it.tok},
                 {"token",        tok_str},
-                {"logprob",      logarithm(it.prob)},
                 {"bytes",        str_to_bytes(it.text_to_send)},
-                {"top_logprobs", it.to_json()},
+                {"top_logprobs", it.to_json(post_sampling_probs)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                },
             });
         }
         return out;
     }
@@ -512,6 +521,7 @@ struct server_task_result_cmpl_final : server_task_result {
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
 
+    bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
 
     slot_params generation_params;
@@ -557,7 +567,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"timings", timings.to_json()},
         };
         if (!stream && !probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
         return res;
     }
@@ -579,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
         if (!stream && probs_output.size() > 0) {
             choice["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
             };
         }
 
@@ -652,6 +662,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
+    bool post_sampling_probs;
     completion_token_output prob_output;
     result_timings timings;
 
@@ -690,7 +701,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             res.push_back({"timings", timings.to_json()});
         }
         if (!prob_output.probs.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
         return res;
     }
@@ -746,7 +757,7 @@ struct server_task_result_cmpl_partial : server_task_result {
 
         if (prob_output.probs.size() > 0) {
             choices[0]["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
             };
         }
 
@@ -1944,28 +1955,53 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
-        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-        int n_vocab = llama_n_vocab(llama_get_model(ctx));
-        size_t n_probs = slot.params.sampling.n_probs;
-
-        bool found_sampled_tok = false;
-        result.probs.reserve(n_probs);
-        for (int i = 0; i < n_vocab; i++) {
-            // set probability for sampled token
-            if (cur[i].id == result.tok) {
-                found_sampled_tok = true;
-                result.prob = cur[i].p;
-            }
-            // set probability for top n_probs tokens
-            result.probs.push_back({
-                cur[i].id,
-                common_detokenize(ctx, {cur[i].id}, special),
-                cur[i].p
-            });
-            // break if we have all the necessary data
-            if (result.probs.size() == n_probs && found_sampled_tok) {
-                break;
-            }
-        }
-    }
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.params.sampling.n_probs;
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < max_probs; i++) {
+                // set probability for sampled token
+                if (cur_p->data[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur_p->data[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
+            }
+        } else {
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
+            }
+        }
+    }
@@ -1997,8 +2033,9 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens  = { tkn.tok };
 
-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose   = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
@@ -2030,13 +2067,14 @@ struct server_context {
         res->timings = slot.get_timings();
         res->prompt  = common_detokenize(ctx, slot.prompt_tokens, true);
 
-        res->truncated       = slot.truncated;
-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
-        res->n_tokens_cached = slot.n_past;
-        res->has_new_line    = slot.has_new_line;
-        res->stopping_word   = slot.stopping_word;
-        res->stop            = slot.stop;
+        res->truncated           = slot.truncated;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->n_tokens_cached     = slot.n_past;
+        res->has_new_line        = slot.has_new_line;
+        res->stopping_word       = slot.stopping_word;
+        res->stop                = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose = slot.params.verbose;
         res->stream  = slot.params.stream;
@@ -2859,7 +2897,7 @@ struct server_context {
                     result.prob = 1.0f; // set later
 
                     if (slot.params.sampling.n_probs > 0) {
-                        populate_token_probs(slot, result, params_base.special, tok_idx);
+                        populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
                     }
 
                     if (!process_token(result, slot)) {
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index c26f982d8..78aaed052 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -309,3 +309,30 @@ def test_n_probs_stream():
                 assert "token" in prob and type(prob["token"]) == str
                 assert "logprob" in prob and prob["logprob"] <= 0.0
                 assert "bytes" in prob and type(prob["bytes"]) == list
+
+
+def test_n_probs_post_sampling():
+    global server
+    server.multi_token_probs = True
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "post_sampling_probs": True,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "bytes" in tok and type(tok["bytes"]) == list
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            assert "bytes" in prob and type(prob["bytes"]) == list
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index e32d74582..43e372fc7 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -50,6 +50,8 @@ def test_embedding_multiple():
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
+        # do not crash on empty input
+        ("", False),
         # single prompt
         ("string", False),
         ([12, 34, 56], False),
@@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
 
     # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400
+    assert "error" in res.body
 
 
 def test_embedding_openai_library_single():
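Because the per-token field switches between `logprob` and `prob` depending on whether the request set `post_sampling_probs`, client code that wants to handle both cases can normalize them. The helper below is a hypothetical sketch, not part of this patch; it only relies on the field names documented in the README hunk above.

```python
# Hypothetical client-side helper (not part of this patch): normalize one
# completion_probabilities entry to a plain probability in [0, 1].
import math

def token_probability(entry: dict) -> float:
    if "prob" in entry:                # request used post_sampling_probs = true
        return entry["prob"]
    return math.exp(entry["logprob"])  # default case: convert log-probability back
```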