server : (refactoring) do not rely on JSON internally (#10643)
* server : (refactoring) reduce usage of json internally * move all response types to struct * wip [no ci] * many fixes * add virtual function * fix index * minor style fix * add std::move * refactor handle_completions_generic * add virtual functions * remove server.hpp * clarify server_sent_event RFC specs * apply review comments * fix model_alias and completion_probabilities * small clean up * remove virtual for to_json_oai_compat() * naming oai_compat --> oaicompat * fix unwanted recursive call * update docs
This commit is contained in:
parent
7736837d62
commit
6c5bc0625f
8 changed files with 985 additions and 697 deletions
|
@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
|
|||
DEBUG=1 ./tests.sh -s -v -x
|
||||
```
|
||||
|
||||
Hint: You can compile and run test in single command, useful for local developement:
|
||||
|
||||
```shell
|
||||
cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
|
||||
```
|
||||
|
||||
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
#!/bin/bash
|
||||
|
||||
# make sure we are in the right directory
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
|
|
|
@ -12,13 +12,13 @@ def create_server():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
@ -30,29 +30,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["model"] == model if model is not None else server.model_alias
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
if truncated:
|
||||
assert choice["finish_reason"] == "length"
|
||||
else:
|
||||
assert choice["finish_reason"] == "stop"
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
|
||||
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
|
||||
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
|
||||
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
global server
|
||||
server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
|
@ -63,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r
|
|||
content = ""
|
||||
for data in res:
|
||||
choice = data["choices"][0]
|
||||
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
|
||||
if choice["finish_reason"] in ["stop", "length"]:
|
||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||
assert data["usage"]["completion_tokens"] == n_predicted
|
||||
assert "content" not in choice["delta"]
|
||||
assert match_regex(re_content, content)
|
||||
# FIXME: not sure why this is incorrect in stream mode
|
||||
# if truncated:
|
||||
# assert choice["finish_reason"] == "length"
|
||||
# else:
|
||||
# assert choice["finish_reason"] == "stop"
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"]
|
||||
|
@ -93,7 +88,7 @@ def test_chat_completion_with_openai_library():
|
|||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "stop"
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
|
||||
|
|
|
@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
|||
content += data["content"]
|
||||
|
||||
|
||||
def test_completion_stream_vs_non_stream():
|
||||
global server
|
||||
server.start()
|
||||
res_stream = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"stream": True,
|
||||
})
|
||||
res_non_stream = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
})
|
||||
content_stream = ""
|
||||
for data in res_stream:
|
||||
content_stream += data["content"]
|
||||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_consistent_result_same_seed(n_slots: int):
|
||||
global server
|
||||
|
@ -221,3 +239,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||
assert len(res.body["content"]) > 10
|
||||
# FIXME: the result is not deterministic when using other slot than slot 0
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
def test_n_probs():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "probs" in tok
|
||||
assert len(tok["probs"]) == 10
|
||||
for prob in tok["probs"]:
|
||||
assert "prob" in prob
|
||||
assert "tok_str" in prob
|
||||
assert 0.0 <= prob["prob"] <= 1.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue