add std::log

This commit is contained in:
Xuan Son Nguyen 2024-12-12 11:16:12 +01:00
parent 7828013689
commit 01afafef93
3 changed files with 13 additions and 13 deletions

View file

@ -342,11 +342,6 @@ struct server_task {
}
}
if (params.sampling.n_probs > 0 && params.cache_prompt) {
SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
params.cache_prompt = false;
}
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
params.oaicompat_model = json_value(data, "model", model_name);
@ -439,7 +434,7 @@ struct completion_token_output {
{"id", p.tok},
{"token", tok_str},
{"bytes", str_to_bytes(p.tok_str)},
{"logprob", p.prob},
{"logprob", logarithm(p.prob)},
});
}
return probs_for_token;
@ -453,7 +448,7 @@ struct completion_token_output {
out.push_back(json {
{"id", it.tok},
{"token", tok_str},
{"logprob", it.prob},
{"logprob", logarithm(it.prob)},
{"bytes", str_to_bytes(it.text_to_send)},
{"top_logprobs", it.to_json()},
});
@ -461,6 +456,11 @@ struct completion_token_output {
return out;
}
static float logarithm(float x) {
// nlohmann::json converts -inf to null, so we need to prevent that
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {

View file

@ -185,7 +185,7 @@ def test_logprobs():
assert res.choices[0].logprobs.content is not None
for token in res.choices[0].logprobs.content:
aggregated_text += token.token
assert 0.0 <= token.logprob <= 1.0
assert token.logprob <= 0.0
assert token.bytes is not None and len(token.bytes) > 0
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
@ -218,7 +218,7 @@ def test_logprobs_stream():
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert 0.0 <= token.logprob <= 1.0
assert token.logprob <= 0.0
assert token.bytes is not None and len(token.bytes) > 0
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0

View file

@ -262,13 +262,13 @@ def test_n_probs():
for tok in res.body["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and len(tok["bytes"]) > 0
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and len(prob["bytes"]) > 0
@ -289,11 +289,11 @@ def test_n_probs_stream():
for tok in data["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and len(tok["bytes"]) > 0
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and len(prob["bytes"]) > 0