From 3249aabc0b5f718b77748e1db7afc5d6800f4661 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 20 Nov 2024 12:49:18 +0100
Subject: [PATCH] add more tests

---
 .github/workflows/server.yml                  |  4 +-
 examples/server/tests/conftest.py             |  2 +-
 examples/server/tests/tests.sh                |  2 +-
 examples/server/tests/unit/test_basic.py      | 10 +--
 examples/server/tests/unit/test_completion.py | 70 +++++++++++++++++++
 examples/server/tests/unit/test_tokenize.py   | 59 ++++++++++++++++
 examples/server/tests/utils.py                | 67 ++++++++++++++++--
 7 files changed, 196 insertions(+), 18 deletions(-)
 create mode 100644 examples/server/tests/unit/test_completion.py
 create mode 100644 examples/server/tests/unit/test_tokenize.py

diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml
index 5860a687e..9e1538831 100644
--- a/.github/workflows/server.yml
+++ b/.github/workflows/server.yml
@@ -180,7 +180,7 @@ jobs:
       run: |
         cd examples/server/tests
         $env:PYTHONIOENCODING = ":replace"
-        pytest -v -s
+        pytest -v -s -x
 
     - name: Slow tests
       id: server_integration_tests_slow
@@ -188,4 +188,4 @@ jobs:
       run: |
         cd examples/server/tests
         $env:SLOW_TESTS = "1"
-        pytest -v -s
+        pytest -v -s -x
diff --git a/examples/server/tests/conftest.py b/examples/server/tests/conftest.py
index 2c2c8e37b..017d1bb84 100644
--- a/examples/server/tests/conftest.py
+++ b/examples/server/tests/conftest.py
@@ -3,7 +3,7 @@ from utils import *
 
 
 # ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def stop_server_after_each_test():
     # do nothing before each test
     yield
diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh
index 7d285c73f..130b09f99 100755
--- a/examples/server/tests/tests.sh
+++ b/examples/server/tests/tests.sh
@@ -4,7 +4,7 @@ set -eu
 
 if [ $# -lt 1 ]
 then
-  pytest -v -s
+  pytest -v -s -x
 else
   pytest "$@"
 fi
diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index 3cf556e64..21b6eda3a 100644
--- a/examples/server/tests/unit/test_basic.py
+++ b/examples/server/tests/unit/test_basic.py
@@ -1,19 +1,13 @@
 import pytest
 from utils import *
 
-server = ServerProcess()
+server = ServerPreset.tinyllamas()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerProcess()
-    server.model_hf_repo = "ggml-org/models"
-    server.model_hf_file = "tinyllamas/stories260K.gguf"
-    server.n_ctx = 256
-    server.n_batch = 32
-    server.n_slots = 2
-    server.n_predict = 64
+    server = ServerPreset.tinyllamas()
 
 
 def test_server_start_simple():
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
new file mode 100644
index 000000000..1503fa1cf
--- /dev/null
+++ b/examples/server/tests/unit/test_completion.py
@@ -0,0 +1,70 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+    })
+    assert res.status_code == 200
+    assert res.body["timings"]["prompt_n"] == n_prompt
+    assert res.body["timings"]["predicted_n"] == n_predicted
+    assert res.body["truncated"] == truncated
+    assert match_regex(re_content, res.body["content"])
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        if data["stop"]:
+            assert data["timings"]["prompt_n"] == n_prompt
+            assert data["timings"]["predicted_n"] == n_predicted
+            assert data["truncated"] == truncated
+            assert match_regex(re_content, content)
+        else:
+            content += data["content"]
+
+
+# FIXME: This test is not working because /completions endpoint is not OAI-compatible
+@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        prompt="I believe the meaning of life is",
+        n=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "length"
+    assert match_regex("(going|bed)+", res.choices[0].text)
diff --git a/examples/server/tests/unit/test_tokenize.py b/examples/server/tests/unit/test_tokenize.py
new file mode 100644
index 000000000..c7cd73860
--- /dev/null
+++ b/examples/server/tests/unit/test_tokenize.py
@@ -0,0 +1,59 @@
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+def test_tokenize_detokenize():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content
+    })
+    assert resTok.status_code == 200
+    assert len(resTok.body["tokens"]) > 5
+    # detokenize
+    resDetok = server.make_request("POST", "/detokenize", data={
+        "tokens": resTok.body["tokens"],
+    })
+    assert resDetok.status_code == 200
+    assert resDetok.body["content"].strip() == content
+
+
+def test_tokenize_with_bos():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    bosId = 1
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "add_special": True,
+    })
+    assert resTok.status_code == 200
+    assert resTok.body["tokens"][0] == bosId
+
+
+def test_tokenize_with_pieces():
+    global server
+    server.start()
+    # tokenize
+    content = "This is a test string with unicode 媽 and emoji 🤗"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "with_pieces": True,
+    })
+    assert resTok.status_code == 200
+    for token in resTok.body["tokens"]:
+        assert "id" in token
+        assert token["id"] > 0
+        assert "piece" in token
+        assert len(token["piece"]) > 0
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 829658722..30620901d 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -5,6 +5,8 @@
 import subprocess
 import os
+import re
+import json
 import sys
 import threading
 import requests
@@ -19,6 +21,7 @@ from typing import (
     Sequence,
     Set,
 )
+from re import RegexFlag
 
 
 class ServerResponse:
@@ -34,6 +37,9 @@ class ServerProcess:
     server_host: str = "127.0.0.1"
     model_hf_repo: str = "ggml-org/models"
     model_hf_file: str = "tinyllamas/stories260K.gguf"
+    model_alias: str = "tinyllama-2"
+    temperature: float = 0.8
+    seed: int = 42
 
     # custom options
     model_alias: str | None = None
@@ -48,7 +54,6 @@ class ServerProcess:
     n_ga_w: int | None = None
     n_predict: int | None = None
     n_prompts: int | None = 0
-    n_server_predict: int | None = None
     slot_save_path: str | None = None
     id_slot: int | None = None
     cache_prompt: bool | None = None
@@ -58,12 +63,9 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
-    seed: int | None = None
     draft: int | None = None
-    server_seed: int | None = None
     user_api_key: str | None = None
     response_format: str | None = None
-    temperature: float | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False
@@ -86,6 +88,10 @@ class ServerProcess:
             self.server_host,
             "--port",
             self.server_port,
+            "--temp",
+            self.temperature,
+            "--seed",
+            self.seed,
         ]
         if self.model_file:
             server_args.extend(["--model", self.model_file])
@@ -119,8 +125,8 @@ class ServerProcess:
             server_args.extend(["--ctx-size", self.n_ctx])
         if self.n_slots:
             server_args.extend(["--parallel", self.n_slots])
-        if self.n_server_predict:
-            server_args.extend(["--n-predict", self.n_server_predict])
+        if self.n_predict:
+            server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
         if self.server_api_key:
@@ -216,12 +222,52 @@ class ServerProcess:
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json()
+        print("Response from server", result.body)
         return result
+
+    def make_stream_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> Iterator[dict]:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        headers = {}
+        if self.user_api_key:
+            headers["Authorization"] = f"Bearer {self.user_api_key}"
+        if method == "POST":
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        for line_bytes in response.iter_lines():
+            line = line_bytes.decode("utf-8")
+            if '[DONE]' in line:
+                break
+            elif line.startswith('data: '):
+                data = json.loads(line[6:])
+                print("Partial response from server", data)
+                yield data
 
 
 server_instances: Set[ServerProcess] = set()
 
 
+class ServerPreset:
+    @staticmethod
+    def tinyllamas() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.model_alias = "tinyllama-2"
+        server.n_ctx = 256
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 64
+        server.seed = 42
+        return server
+
+
 def multiple_post_requests(
     server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
 ) -> Sequence[ServerResponse]:
@@ -248,3 +294,12 @@ def multiple_post_requests(
         thread.join()
 
     return results
+
+
+def match_regex(regex: str, text: str) -> bool:
+    return (
+        re.compile(
+            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
+        ).search(text)
+        is not None
+    )
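
Note (not part of the patch): a minimal sketch of how a follow-up unit test could reuse the helpers this patch introduces (ServerPreset.tinyllamas(), ServerProcess.make_request() and the new --temp/--seed plumbing). The file name, the test name, and the idea of overriding preset fields before start() are illustrative assumptions, not something this diff adds.

# unit/test_example.py (hypothetical)
import pytest
from utils import *

server = ServerPreset.tinyllamas()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    # reuse the shared preset, then tweak individual fields before start();
    # temperature and seed are forwarded to the server as --temp / --seed
    # (assumption: plain attribute assignment is the intended way to customize a preset)
    server = ServerPreset.tinyllamas()
    server.temperature = 0.0
    server.seed = 1234


def test_completion_short_prompt():
    global server
    server.start()
    # the prompt fits in the 256-token context, so all n_predict tokens should be generated
    res = server.make_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
    })
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == 8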