From 01afafef93ad32b8be48987a0b86649bf176e39f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 11:16:12 +0100 Subject: [PATCH] add std::log --- examples/server/server.cpp | 14 +++++++------- examples/server/tests/unit/test_chat_completion.py | 4 ++-- examples/server/tests/unit/test_completion.py | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2c94318b4..5a3f5d889 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -342,11 +342,6 @@ struct server_task { } } - if (params.sampling.n_probs > 0 && params.cache_prompt) { - SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs); - params.cache_prompt = false; - } - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; params.oaicompat_model = json_value(data, "model", model_name); @@ -439,7 +434,7 @@ struct completion_token_output { {"id", p.tok}, {"token", tok_str}, {"bytes", str_to_bytes(p.tok_str)}, - {"logprob", p.prob}, + {"logprob", logarithm(p.prob)}, }); } return probs_for_token; @@ -453,7 +448,7 @@ struct completion_token_output { out.push_back(json { {"id", it.tok}, {"token", tok_str}, - {"logprob", it.prob}, + {"logprob", logarithm(it.prob)}, {"bytes", str_to_bytes(it.text_to_send)}, {"top_logprobs", it.to_json()}, }); @@ -461,6 +456,11 @@ struct completion_token_output { return out; } + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + static std::vector str_to_bytes(const std::string & str) { std::vector bytes; for (unsigned char c : str) { diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 299472fa4..ce94398d6 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -185,7 +185,7 @@ def test_logprobs(): assert res.choices[0].logprobs.content is not None for token in res.choices[0].logprobs.content: aggregated_text += token.token - assert 0.0 <= token.logprob <= 1.0 + assert token.logprob <= 0.0 assert token.bytes is not None and len(token.bytes) > 0 assert len(token.top_logprobs) > 0 assert aggregated_text == output_text @@ -218,7 +218,7 @@ def test_logprobs_stream(): assert choice.logprobs.content is not None for token in choice.logprobs.content: aggregated_text += token.token - assert 0.0 <= token.logprob <= 1.0 + assert token.logprob <= 0.0 assert token.bytes is not None and len(token.bytes) > 0 assert token.top_logprobs is not None assert len(token.top_logprobs) > 0 diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 4c89ee3ee..9e91c5da2 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -262,13 +262,13 @@ def test_n_probs(): for tok in res.body["completion_probabilities"]: assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str - assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "logprob" in tok and tok["logprob"] <= 0.0 assert "bytes" in tok and len(tok["bytes"]) > 0 assert len(tok["top_logprobs"]) == 10 for prob in tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str - assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "logprob" in prob and prob["logprob"] <= 0.0 assert "bytes" in prob and len(prob["bytes"]) > 0 @@ -289,11 +289,11 @@ def test_n_probs_stream(): for tok in data["completion_probabilities"]: assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str - assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "logprob" in tok and tok["logprob"] <= 0.0 assert "bytes" in tok and len(tok["bytes"]) > 0 assert len(tok["top_logprobs"]) == 10 for prob in tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str - assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "logprob" in prob and prob["logprob"] <= 0.0 assert "bytes" in prob and len(prob["bytes"]) > 0