add more tests

Xuan Son Nguyen 2024-11-20 12:49:18 +01:00
parent d7de41302b
commit 3249aabc0b
7 changed files with 196 additions and 18 deletions

View file

@@ -180,7 +180,7 @@ jobs:
         run: |
           cd examples/server/tests
           $env:PYTHONIOENCODING = ":replace"
-          pytest -v -s
+          pytest -v -s -x
       - name: Slow tests
         id: server_integration_tests_slow
@@ -188,4 +188,4 @@ jobs:
         run: |
           cd examples/server/tests
           $env:SLOW_TESTS = "1"
-          pytest -v -s
+          pytest -v -s -x

View file

@@ -3,7 +3,7 @@ from utils import *

 # ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def stop_server_after_each_test():
     # do nothing before each test
     yield
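The hunk above ends before the teardown half of the fixture. A minimal sketch of the full per-test cleanup it implies, assuming utils.py's module-level `server_instances` set tracks running servers and `ServerProcess` has a `stop()` method (neither body is shown in this diff):

```python
import pytest
from utils import *  # assumed to provide server_instances and ServerProcess

@pytest.fixture(autouse=True)
def stop_server_after_each_test():
    # do nothing before each test
    yield
    # after each test: stop every server the test left running
    instances = set(server_instances)  # copy first; stop() may mutate the set
    for server in instances:
        server.stop()
```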

View file

@@ -4,7 +4,7 @@ set -eu

 if [ $# -lt 1 ]
 then
-  pytest -v -s
+  pytest -v -s -x
 else
   pytest "$@"
 fi
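All three launch points now pass `-x`, which makes pytest abort on the first failure instead of running the remaining tests against a possibly broken server. For completeness, the same flags can be driven programmatically through pytest's public API (a sketch, not part of this commit):

```python
import sys
import pytest

# equivalent to `pytest -v -s -x`: verbose, no output capture, stop on first failure
sys.exit(pytest.main(["-v", "-s", "-x"]))
```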

View file

@@ -1,19 +1,13 @@
 import pytest
 from utils import *

-server = ServerProcess()
+server = ServerPreset.tinyllamas()

 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerProcess()
-    server.model_hf_repo = "ggml-org/models"
-    server.model_hf_file = "tinyllamas/stories260K.gguf"
-    server.n_ctx = 256
-    server.n_batch = 32
-    server.n_slots = 2
-    server.n_predict = 64
+    server = ServerPreset.tinyllamas()

 def test_server_start_simple():
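This module-level preset plus a module-scoped `create_server` fixture is the pattern the new test files below follow. A minimal usage sketch, assuming a module wants to tweak one preset default before any test starts the server (the `n_slots` override and the `/health` probe are illustrative, not from this commit):

```python
import pytest
from utils import *

server = ServerPreset.tinyllamas()

@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllamas()
    server.n_slots = 1  # hypothetical per-module override, set before start()

def test_server_is_healthy():
    global server
    server.start()
    res = server.make_request("GET", "/health")
    assert res.status_code == 200
```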

View file

@@ -0,0 +1,70 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+    })
+    assert res.status_code == 200
+    assert res.body["timings"]["prompt_n"] == n_prompt
+    assert res.body["timings"]["predicted_n"] == n_predicted
+    assert res.body["truncated"] == truncated
+    assert match_regex(re_content, res.body["content"])
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        if data["stop"]:
+            assert data["timings"]["prompt_n"] == n_prompt
+            assert data["timings"]["predicted_n"] == n_predicted
+            assert data["truncated"] == truncated
+            assert match_regex(re_content, content)
+        else:
+            content += data["content"]
+
+
+# FIXME: This test is not working because /completions endpoint is not OAI-compatible
+@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "length"
+    assert match_regex("(going|bed)+", res.choices[0].text)
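On the skipped test: the server's native `/completion` route uses its own JSON schema, so the OpenAI client can only target the OAI-compatible chat route. A minimal sketch of the chat-based equivalent, assuming the server exposes `/v1/chat/completions` and accepts the alias as the model name (neither is exercised in this commit):

```python
from openai import OpenAI

# host/port below are illustrative; in the test suite they would come
# from server.server_host / server.server_port
client = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080/v1")
res = client.chat.completions.create(
    model="tinyllama-2",  # assumed to map to the server's model_alias
    messages=[{"role": "user", "content": "I believe the meaning of life is"}],
    max_tokens=8,
    seed=42,
)
print(res.choices[0].message.content)
```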

View file

@@ -0,0 +1,59 @@
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+def test_tokenize_detokenize():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content
+    })
+    assert resTok.status_code == 200
+    assert len(resTok.body["tokens"]) > 5
+    # detokenize
+    resDetok = server.make_request("POST", "/detokenize", data={
+        "tokens": resTok.body["tokens"],
+    })
+    assert resDetok.status_code == 200
+    assert resDetok.body["content"].strip() == content
+
+
+def test_tokenize_with_bos():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    bosId = 1
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "add_special": True,
+    })
+    assert resTok.status_code == 200
+    assert resTok.body["tokens"][0] == bosId
+
+
+def test_tokenize_with_pieces():
+    global server
+    server.start()
+    # tokenize
+    content = "This is a test string with unicode 媽 and emoji 🤗"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "with_pieces": True,
+    })
+    assert resTok.status_code == 200
+    for token in resTok.body["tokens"]:
+        assert "id" in token
+        assert token["id"] > 0
+        assert "piece" in token
+        assert len(token["piece"]) > 0
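For reference, `with_pieces` switches the `tokens` payload from a flat id list to a list of objects, which is what the per-token assertions above check. A sketch of the two shapes (ids and pieces are made up for illustration, not captured from the model):

```python
# "with_pieces" absent or False (illustrative ids):
body_plain = {"tokens": [1, 1724, 338, 278, 7483]}

# "with_pieces": True -- each entry pairs a token id with its text piece:
body_pieces = {"tokens": [
    {"id": 1724, "piece": "What"},
    {"id": 338, "piece": " is"},
]}
```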

View file

@@ -5,6 +5,8 @@
 import subprocess
 import os
 import re
+import json
+import sys
 import threading
 import requests
@@ -19,6 +21,7 @@ from typing import (
     Sequence,
     Set,
 )
+from re import RegexFlag

 class ServerResponse:
@@ -34,6 +37,9 @@ class ServerProcess:
     server_host: str = "127.0.0.1"
     model_hf_repo: str = "ggml-org/models"
     model_hf_file: str = "tinyllamas/stories260K.gguf"
+    model_alias: str = "tinyllama-2"
+    temperature: float = 0.8
+    seed: int = 42

     # custom options
-    model_alias: str | None = None
@@ -48,7 +54,6 @@ class ServerProcess:
     n_ga_w: int | None = None
     n_predict: int | None = None
     n_prompts: int | None = 0
-    n_server_predict: int | None = None
     slot_save_path: str | None = None
     id_slot: int | None = None
     cache_prompt: bool | None = None
@@ -58,12 +63,9 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
-    seed: int | None = None
     draft: int | None = None
-    server_seed: int | None = None
     user_api_key: str | None = None
     response_format: str | None = None
-    temperature: float | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False
@@ -86,6 +88,10 @@ class ServerProcess:
             self.server_host,
             "--port",
             self.server_port,
+            "--temp",
+            self.temperature,
+            "--seed",
+            self.seed,
         ]
         if self.model_file:
             server_args.extend(["--model", self.model_file])
@@ -119,8 +125,8 @@ class ServerProcess:
             server_args.extend(["--ctx-size", self.n_ctx])
         if self.n_slots:
             server_args.extend(["--parallel", self.n_slots])
-        if self.n_server_predict:
-            server_args.extend(["--n-predict", self.n_server_predict])
+        if self.n_predict:
+            server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
         if self.server_api_key:
@@ -216,12 +222,52 @@ class ServerProcess:
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json()
+        print("Response from server", result.body)
         return result

+    def make_stream_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> Iterator[dict]:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        if headers is None:
+            headers = {}
+        if self.user_api_key:
+            headers["Authorization"] = f"Bearer {self.user_api_key}"
+        if method == "POST":
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        for line_bytes in response.iter_lines():
+            line = line_bytes.decode("utf-8")
+            if '[DONE]' in line:
+                break
+            elif line.startswith('data: '):
+                data = json.loads(line[6:])
+                print("Partial response from server", data)
+                yield data
 server_instances: Set[ServerProcess] = set()

+
+class ServerPreset:
+    @staticmethod
+    def tinyllamas() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.model_alias = "tinyllama-2"
+        server.n_ctx = 256
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 64
+        server.seed = 42
+        return server
+
 def multiple_post_requests(
     server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
 ) -> Sequence[ServerResponse]:
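The `make_stream_request` helper added above yields one decoded JSON object per SSE `data:` line and stops at `[DONE]`. A minimal consumption sketch mirroring `test_completion_stream` (the prompt is illustrative):

```python
server = ServerPreset.tinyllamas()
server.start()

content = ""
for chunk in server.make_stream_request("POST", "/completion", data={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
    "stream": True,
}):
    if chunk["stop"]:  # the final chunk carries timings and truncation info
        print(chunk["timings"])
    else:
        content += chunk["content"]
print(content)
```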
@@ -248,3 +294,12 @@ def multiple_post_requests(
         thread.join()
     return results
+
+def match_regex(regex: str, text: str) -> bool:
+    return (
+        re.compile(
+            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
+        ).search(text)
+        is not None
+    )
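`match_regex` is the fuzzy output check used by the completion tests: a case-insensitive, multiline, dot-matches-newline search anywhere in the text. A quick usage sketch:

```python
assert match_regex("(going|bed)+", "I believe the meaning of life is GOING to bed")
assert not match_regex("(princesses)+", "nothing royal here")
```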