server : fix logprobs, make it OAI-compatible (#10783)
* server : fix logprobs, make it openai-compatible * update docs * add std::log * return pre-sampling p * sort before apply softmax * add comment * fix test * set p for sampled token * update docs * add --multi-token-probs * update docs * add `post_sampling_probs` option * update docs [no ci] * remove --multi-token-probs * "top_probs" with "post_sampling_probs" * resolve review comments * rename struct token_prob to prob_info * correct comment placement * fix setting prob for sampled token
This commit is contained in:
parent
a3c33b1dce
commit
57bb2c40cd
6 changed files with 396 additions and 107 deletions
|
@ -270,9 +270,68 @@ def test_n_probs():
|
|||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "probs" in tok
|
||||
assert len(tok["probs"]) == 10
|
||||
for prob in tok["probs"]:
|
||||
assert "prob" in prob
|
||||
assert "tok_str" in prob
|
||||
assert 0.0 <= prob["prob"] <= 1.0
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_stream():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
"stream": True,
|
||||
})
|
||||
for data in res:
|
||||
if data["stop"] == False:
|
||||
assert "completion_probabilities" in data
|
||||
assert len(data["completion_probabilities"]) == 1
|
||||
for tok in data["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_post_sampling():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
"post_sampling_probs": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_probs"]) == 10
|
||||
for prob in tok["top_probs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
|
||||
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue