add more tests
parent d7de41302b
commit 3249aabc0b
7 changed files with 196 additions and 18 deletions
.github/workflows/server.yml (vendored, 4 changes)
@@ -180,7 +180,7 @@ jobs:
         run: |
           cd examples/server/tests
           $env:PYTHONIOENCODING = ":replace"
-          pytest -v -s
+          pytest -v -s -x

       - name: Slow tests
         id: server_integration_tests_slow
@@ -188,4 +188,4 @@ jobs:
         run: |
           cd examples/server/tests
           $env:SLOW_TESTS = "1"
-          pytest -v -s
+          pytest -v -s -x
examples/server/tests/conftest.py

@@ -3,7 +3,7 @@ from utils import *


 # ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def stop_server_after_each_test():
     # do nothing before each test
     yield
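With the "session" scope dropped, this fixture now wraps every test instead of running once per run, so a server started by one test is cleaned up before the next test begins. The teardown half of the fixture lies outside this hunk; a minimal sketch of what it presumably does, using the server_instances set from utils.py (the exact body and ServerProcess.stop() are assumptions, not shown in this diff):

    # after the yield: stop any server a test left running (sketch; assumes ServerProcess.stop() exists)
    for server in set(server_instances):  # iterate over a copy, since stopping may mutate the set
        server.stop()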
examples/server/tests/tests.sh

@@ -4,7 +4,7 @@ set -eu

 if [ $# -lt 1 ]
 then
-    pytest -v -s
+    pytest -v -s -x
 else
     pytest "$@"
 fi
examples/server/tests/unit/test_basic.py

@@ -1,19 +1,13 @@
 import pytest
 from utils import *

-server = ServerProcess()
+server = ServerPreset.tinyllamas()


 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerProcess()
-    server.model_hf_repo = "ggml-org/models"
-    server.model_hf_file = "tinyllamas/stories260K.gguf"
-    server.n_ctx = 256
-    server.n_batch = 32
-    server.n_slots = 2
-    server.n_predict = 64
+    server = ServerPreset.tinyllamas()


 def test_server_start_simple():
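The module-level fixture now just re-creates the preset, so the tinyllama configuration lives in one place (utils.py) instead of being repeated per module. A module that needs a different setup can still take the preset and then override individual ServerProcess fields before starting; a small sketch (the override values below are hypothetical, not part of this commit):

    server = ServerPreset.tinyllamas()
    server.n_slots = 1     # hypothetical override on top of the preset defaults
    server.n_predict = 32  # any ServerProcess field can be adjusted before server.start()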
examples/server/tests/unit/test_completion.py (new file, 70 lines)

@@ -0,0 +1,70 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+    })
+    assert res.status_code == 200
+    assert res.body["timings"]["prompt_n"] == n_prompt
+    assert res.body["timings"]["predicted_n"] == n_predicted
+    assert res.body["truncated"] == truncated
+    assert match_regex(re_content, res.body["content"])
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        if data["stop"]:
+            assert data["timings"]["prompt_n"] == n_prompt
+            assert data["timings"]["predicted_n"] == n_predicted
+            assert data["truncated"] == truncated
+            assert match_regex(re_content, content)
+        else:
+            content += data["content"]
+
+
+# FIXME: This test is not working because /completions endpoint is not OAI-compatible
+@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        prompt="I believe the meaning of life is",
+        n=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "length"
+    assert match_regex("(going|bed)+", res.choices[0].text)
examples/server/tests/unit/test_tokenize.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+def test_tokenize_detokenize():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content
+    })
+    assert resTok.status_code == 200
+    assert len(resTok.body["tokens"]) > 5
+    # detokenize
+    resDetok = server.make_request("POST", "/detokenize", data={
+        "tokens": resTok.body["tokens"],
+    })
+    assert resDetok.status_code == 200
+    assert resDetok.body["content"].strip() == content
+
+
+def test_tokenize_with_bos():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    bosId = 1
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "add_special": True,
+    })
+    assert resTok.status_code == 200
+    assert resTok.body["tokens"][0] == bosId
+
+
+def test_tokenize_with_pieces():
+    global server
+    server.start()
+    # tokenize
+    content = "This is a test string with unicode 媽 and emoji 🤗"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "with_pieces": True,
+    })
+    assert resTok.status_code == 200
+    for token in resTok.body["tokens"]:
+        assert "id" in token
+        assert token["id"] > 0
+        assert "piece" in token
+        assert len(token["piece"]) > 0
examples/server/tests/utils.py

@@ -5,6 +5,8 @@

 import subprocess
 import os
+import re
+import json
 import sys
 import threading
 import requests
@@ -19,6 +21,7 @@ from typing import (
     Sequence,
     Set,
 )
+from re import RegexFlag


 class ServerResponse:
@@ -34,6 +37,9 @@ class ServerProcess:
     server_host: str = "127.0.0.1"
     model_hf_repo: str = "ggml-org/models"
     model_hf_file: str = "tinyllamas/stories260K.gguf"
+    model_alias: str = "tinyllama-2"
+    temperature: float = 0.8
+    seed: int = 42

     # custom options
     model_alias: str | None = None
@@ -48,7 +54,6 @@ class ServerProcess:
     n_ga_w: int | None = None
     n_predict: int | None = None
     n_prompts: int | None = 0
-    n_server_predict: int | None = None
     slot_save_path: str | None = None
     id_slot: int | None = None
     cache_prompt: bool | None = None
@@ -58,12 +63,9 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
-    seed: int | None = None
     draft: int | None = None
-    server_seed: int | None = None
     user_api_key: str | None = None
     response_format: str | None = None
-    temperature: float | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False

@@ -86,6 +88,10 @@ class ServerProcess:
             self.server_host,
             "--port",
             self.server_port,
+            "--temp",
+            self.temperature,
+            "--seed",
+            self.seed,
         ]
         if self.model_file:
             server_args.extend(["--model", self.model_file])
@@ -119,8 +125,8 @@ class ServerProcess:
             server_args.extend(["--ctx-size", self.n_ctx])
         if self.n_slots:
             server_args.extend(["--parallel", self.n_slots])
-        if self.n_server_predict:
-            server_args.extend(["--n-predict", self.n_server_predict])
+        if self.n_predict:
+            server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
         if self.server_api_key:
@@ -216,12 +222,52 @@
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json()
+        print("Response from server", result.body)
         return result

+    def make_stream_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> Iterator[dict]:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        headers = {}
+        if self.user_api_key:
+            headers["Authorization"] = f"Bearer {self.user_api_key}"
+        if method == "POST":
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        for line_bytes in response.iter_lines():
+            line = line_bytes.decode("utf-8")
+            if '[DONE]' in line:
+                break
+            elif line.startswith('data: '):
+                data = json.loads(line[6:])
+                print("Partial response from server", data)
+                yield data
+
+
 server_instances: Set[ServerProcess] = set()


+class ServerPreset:
+    @staticmethod
+    def tinyllamas() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.model_alias = "tinyllama-2"
+        server.n_ctx = 256
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 64
+        server.seed = 42
+        return server
+
+
 def multiple_post_requests(
     server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
 ) -> Sequence[ServerResponse]:
@@ -248,3 +294,12 @@ def multiple_post_requests(
         thread.join()

     return results
+
+
+def match_regex(regex: str, text: str) -> bool:
+    return (
+        re.compile(
+            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
+        ).search(text)
+        is not None
+    )
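Together, the new helpers keep each test module short: ServerPreset supplies the shared tinyllama configuration, make_stream_request exposes the SSE stream as a generator of parsed JSON chunks, and match_regex does a case-insensitive, DOTALL search over the generated text. A minimal usage sketch stitched together from the tests above (no functionality assumed beyond what this commit shows):

    server = ServerPreset.tinyllamas()
    server.start()
    content = ""
    for chunk in server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "stream": True,
    }):
        if not chunk["stop"]:
            content += chunk["content"]  # accumulate partial tokens until the stop chunk
    assert match_regex("(going|bed)+", content)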