add more tests
This commit is contained in:
parent
d7de41302b
commit
3249aabc0b
7 changed files with 196 additions and 18 deletions
4
.github/workflows/server.yml
vendored
4
.github/workflows/server.yml
vendored
|
@ -180,7 +180,7 @@ jobs:
|
|||
run: |
|
||||
cd examples/server/tests
|
||||
$env:PYTHONIOENCODING = ":replace"
|
||||
pytest -v -s
|
||||
pytest -v -s -x
|
||||
|
||||
- name: Slow tests
|
||||
id: server_integration_tests_slow
|
||||
|
@ -188,4 +188,4 @@ jobs:
|
|||
run: |
|
||||
cd examples/server/tests
|
||||
$env:SLOW_TESTS = "1"
|
||||
pytest -v -s
|
||||
pytest -v -s -x
|
||||
|
|
|
@ -3,7 +3,7 @@ from utils import *
|
|||
|
||||
|
||||
# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def stop_server_after_each_test():
|
||||
# do nothing before each test
|
||||
yield
|
||||
|
|
|
@ -4,7 +4,7 @@ set -eu
|
|||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
pytest -v -s
|
||||
pytest -v -s -x
|
||||
else
|
||||
pytest "$@"
|
||||
fi
|
||||
|
|
|
@ -1,19 +1,13 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerProcess()
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K.gguf"
|
||||
server.n_ctx = 256
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 64
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
def test_server_start_simple():
|
||||
|
|
70
examples/server/tests/unit/test_completion.py
Normal file
70
examples/server/tests/unit/test_completion.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
|
||||
])
|
||||
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == n_prompt
|
||||
assert res.body["timings"]["predicted_n"] == n_predicted
|
||||
assert res.body["truncated"] == truncated
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
|
||||
])
|
||||
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
if data["stop"]:
|
||||
assert data["timings"]["prompt_n"] == n_prompt
|
||||
assert data["timings"]["predicted_n"] == n_predicted
|
||||
assert data["truncated"] == truncated
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
content += data["content"]
|
||||
|
||||
|
||||
# FIXME: This test is not working because /completions endpoint is not OAI-compatible
|
||||
@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="I believe the meaning of life is",
|
||||
n=8,
|
||||
seed=42,
|
||||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
59
examples/server/tests/unit/test_tokenize.py
Normal file
59
examples/server/tests/unit/test_tokenize.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllamas()
|
||||
|
||||
|
||||
def test_tokenize_detokenize():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "What is the capital of France ?"
|
||||
resTok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content
|
||||
})
|
||||
assert resTok.status_code == 200
|
||||
assert len(resTok.body["tokens"]) > 5
|
||||
# detokenize
|
||||
resDetok = server.make_request("POST", "/detokenize", data={
|
||||
"tokens": resTok.body["tokens"],
|
||||
})
|
||||
assert resDetok.status_code == 200
|
||||
assert resDetok.body["content"].strip() == content
|
||||
|
||||
|
||||
def test_tokenize_with_bos():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "What is the capital of France ?"
|
||||
bosId = 1
|
||||
resTok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content,
|
||||
"add_special": True,
|
||||
})
|
||||
assert resTok.status_code == 200
|
||||
assert resTok.body["tokens"][0] == bosId
|
||||
|
||||
|
||||
def test_tokenize_with_pieces():
|
||||
global server
|
||||
server.start()
|
||||
# tokenize
|
||||
content = "This is a test string with unicode 媽 and emoji 🤗"
|
||||
resTok = server.make_request("POST", "/tokenize", data={
|
||||
"content": content,
|
||||
"with_pieces": True,
|
||||
})
|
||||
assert resTok.status_code == 200
|
||||
for token in resTok.body["tokens"]:
|
||||
assert "id" in token
|
||||
assert token["id"] > 0
|
||||
assert "piece" in token
|
||||
assert len(token["piece"]) > 0
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
import subprocess
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import requests
|
||||
|
@ -19,6 +21,7 @@ from typing import (
|
|||
Sequence,
|
||||
Set,
|
||||
)
|
||||
from re import RegexFlag
|
||||
|
||||
|
||||
class ServerResponse:
|
||||
|
@ -34,6 +37,9 @@ class ServerProcess:
|
|||
server_host: str = "127.0.0.1"
|
||||
model_hf_repo: str = "ggml-org/models"
|
||||
model_hf_file: str = "tinyllamas/stories260K.gguf"
|
||||
model_alias: str = "tinyllama-2"
|
||||
temperature: float = 0.8
|
||||
seed: int = 42
|
||||
|
||||
# custom options
|
||||
model_alias: str | None = None
|
||||
|
@ -48,7 +54,6 @@ class ServerProcess:
|
|||
n_ga_w: int | None = None
|
||||
n_predict: int | None = None
|
||||
n_prompts: int | None = 0
|
||||
n_server_predict: int | None = None
|
||||
slot_save_path: str | None = None
|
||||
id_slot: int | None = None
|
||||
cache_prompt: bool | None = None
|
||||
|
@ -58,12 +63,9 @@ class ServerProcess:
|
|||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
seed: int | None = None
|
||||
draft: int | None = None
|
||||
server_seed: int | None = None
|
||||
user_api_key: str | None = None
|
||||
response_format: str | None = None
|
||||
temperature: float | None = None
|
||||
lora_file: str | None = None
|
||||
disable_ctx_shift: int | None = False
|
||||
|
||||
|
@ -86,6 +88,10 @@ class ServerProcess:
|
|||
self.server_host,
|
||||
"--port",
|
||||
self.server_port,
|
||||
"--temp",
|
||||
self.temperature,
|
||||
"--seed",
|
||||
self.seed,
|
||||
]
|
||||
if self.model_file:
|
||||
server_args.extend(["--model", self.model_file])
|
||||
|
@ -119,8 +125,8 @@ class ServerProcess:
|
|||
server_args.extend(["--ctx-size", self.n_ctx])
|
||||
if self.n_slots:
|
||||
server_args.extend(["--parallel", self.n_slots])
|
||||
if self.n_server_predict:
|
||||
server_args.extend(["--n-predict", self.n_server_predict])
|
||||
if self.n_predict:
|
||||
server_args.extend(["--n-predict", self.n_predict])
|
||||
if self.slot_save_path:
|
||||
server_args.extend(["--slot-save-path", self.slot_save_path])
|
||||
if self.server_api_key:
|
||||
|
@ -216,12 +222,52 @@ class ServerProcess:
|
|||
result.headers = dict(response.headers)
|
||||
result.status_code = response.status_code
|
||||
result.body = response.json()
|
||||
print("Response from server", result.body)
|
||||
return result
|
||||
|
||||
def make_stream_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
) -> Iterator[dict]:
|
||||
url = f"http://{self.server_host}:{self.server_port}{path}"
|
||||
headers = {}
|
||||
if self.user_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.user_api_key}"
|
||||
if method == "POST":
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
else:
|
||||
raise ValueError(f"Unimplemented method: {method}")
|
||||
for line_bytes in response.iter_lines():
|
||||
line = line_bytes.decode("utf-8")
|
||||
if '[DONE]' in line:
|
||||
break
|
||||
elif line.startswith('data: '):
|
||||
data = json.loads(line[6:])
|
||||
print("Partial response from server", data)
|
||||
yield data
|
||||
|
||||
|
||||
server_instances: Set[ServerProcess] = set()
|
||||
|
||||
|
||||
class ServerPreset:
|
||||
@staticmethod
|
||||
def tinyllamas() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K.gguf"
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.n_ctx = 256
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 64
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
|
||||
def multiple_post_requests(
|
||||
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
|
||||
) -> Sequence[ServerResponse]:
|
||||
|
@ -248,3 +294,12 @@ def multiple_post_requests(
|
|||
thread.join()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def match_regex(regex: str, text: str) -> bool:
|
||||
return (
|
||||
re.compile(
|
||||
regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
|
||||
).search(text)
|
||||
is not None
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue