server : fix logprobs, make it OAI-compatible (#10783)
* server : fix logprobs, make it openai-compatible * update docs * add std::log * return pre-sampling p * sort before apply softmax * add comment * fix test * set p for sampled token * update docs * add --multi-token-probs * update docs * add `post_sampling_probs` option * update docs [no ci] * remove --multi-token-probs * "top_probs" with "post_sampling_probs" * resolve review comments * rename struct token_prob to prob_info * correct comment placement * fix setting prob for sampled token
This commit is contained in:
parent
a3c33b1dce
commit
57bb2c40cd
6 changed files with 396 additions and 107 deletions
|
@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
|
|||
seed=42,
|
||||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
|
@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
|
|||
assert "predicted_per_second" in data["timings"]
|
||||
assert "predicted_n" in data["timings"]
|
||||
assert data["timings"]["predicted_n"] <= 10
|
||||
|
||||
|
||||
def test_logprobs():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
)
|
||||
output_text = res.choices[0].message.content
|
||||
aggregated_text = ''
|
||||
assert res.choices[0].logprobs is not None
|
||||
assert res.choices[0].logprobs.content is not None
|
||||
for token in res.choices[0].logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
||||
|
||||
def test_logprobs_stream():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
aggregated_text = ''
|
||||
for data in res:
|
||||
choice = data.choices[0]
|
||||
if choice.finish_reason is None:
|
||||
if choice.delta.content:
|
||||
output_text += choice.delta.content
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.content is not None
|
||||
for token in choice.logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert token.top_logprobs is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue