server : replace behave with pytest

Xuan Son Nguyen 2024-11-19 23:29:46 +01:00
parent 42ae10bbcd
commit 3acaf58e38
6 changed files with 298 additions and 4 deletions

View file

@@ -1 +1,2 @@
 .venv
+tmp

View file

@@ -0,0 +1,15 @@
import pytest
from utils import *


# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
@pytest.fixture(autouse=True)
def stop_server_after_each_test():
    # do nothing before each test
    yield
    # stop all servers after each test
    instances = set(
        server_instances
    )  # copy the set to prevent 'Set changed size during iteration'
    for server in instances:
        server.stop()
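
With this autouse fixture in place, individual tests need no teardown of their own: every ServerProcess started during a test is registered in server_instances and stopped once the test returns. A minimal sketch of a test that relies on this (the test name is illustrative; ServerProcess and the /health route come from the files below):

from utils import ServerProcess


def test_health_no_explicit_teardown():
    # no stop() call needed: the autouse fixture in conftest.py
    # stops every registered server after the test finishes
    server = ServerProcess()
    server.start()
    res = server.make_request("GET", "/health")
    assert res.status_code == 200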

View file

@@ -1,5 +1,5 @@
 aiohttp~=3.9.3
-behave~=1.2.6
+pytest~=8.3.3
 huggingface_hub~=0.23.2
 numpy~=1.26.4
 openai~=1.30.3

View file

@@ -4,8 +4,7 @@ set -eu
 if [ $# -lt 1 ]
 then
-  # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
+  pytest -v -s
 else
-  behave "$@"
+  pytest "$@"
 fi
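
Because the script forwards its arguments to pytest when any are given, a single file or a single test can be selected from the command line (paths shown for illustration):

    ./tests.sh unit/test_basic.py -v
    ./tests.sh unit/test_basic.py::test_server_props

With no arguments it runs the whole suite via pytest -v -s.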

View file

@@ -0,0 +1,31 @@
import pytest
from utils import *

server = ServerProcess()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerProcess()
    server.model_hf_repo = "ggml-org/models"
    server.model_hf_file = "tinyllamas/stories260K.gguf"
    server.n_ctx = 256
    server.n_batch = 32
    server.n_slots = 2
    server.n_predict = 64


def test_server_start_simple():
    global server
    server.start()
    res = server.make_request("GET", "/health")
    assert res.status_code == 200


def test_server_props():
    global server
    server.start()
    res = server.make_request("GET", "/props")
    assert res.status_code == 200
    assert res.body["total_slots"] == server.n_slots
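
A further test in the same module could exercise the /slots endpoint that ServerProcess.start() already enables (via --slots) and polls for readiness. A sketch, with a hypothetical test name:

def test_server_slots():
    global server
    server.start()
    # --slots is always passed by ServerProcess.start(), so the endpoint
    # should respond once the readiness poll has succeeded
    res = server.make_request("GET", "/slots")
    assert res.status_code == 200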

View file

@@ -0,0 +1,248 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import subprocess
import os
import sys
import threading
import requests
import time
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    Literal,
    Sequence,
    Set,
)


class ServerResponse:
    headers: dict
    status_code: int
    body: dict

class ServerProcess:
    # default options
    debug: bool = False
    server_port: int = 8080
    server_host: str = "127.0.0.1"
    model_hf_repo: str = "ggml-org/models"
    model_hf_file: str = "tinyllamas/stories260K.gguf"

    # custom options
    model_alias: str | None = None
    model_url: str | None = None
    model_file: str | None = None
    n_threads: int | None = None
    n_gpu_layer: str | None = None
    n_batch: int | None = None
    n_ubatch: int | None = None
    n_ctx: int | None = None
    n_ga: int | None = None
    n_ga_w: int | None = None
    n_predict: int | None = None
    n_prompts: int | None = 0
    n_server_predict: int | None = None
    slot_save_path: str | None = None
    id_slot: int | None = None
    cache_prompt: bool | None = None
    n_slots: int | None = None
    server_api_key: str | None = None
    server_continuous_batching: bool | None = False
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    seed: int | None = None
    draft: int | None = None
    server_seed: int | None = None
    user_api_key: str | None = None
    response_format: str | None = None
    temperature: float | None = None
    lora_file: str | None = None
    disable_ctx_shift: int | None = False

    # session variables
    process: subprocess.Popen | None = None

    def __init__(self):
        pass
    def start(self, timeout_seconds: int = 10) -> None:
        if "LLAMA_SERVER_BIN_PATH" in os.environ:
            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
        elif os.name == "nt":
            server_path = "../../../build/bin/Release/llama-server.exe"
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--slots",  # required to query slot status via the /slots endpoint
            "--host",
            self.server_host,
            "--port",
            self.server_port,
        ]
        if self.model_file:
            server_args.extend(["--model", self.model_file])
        if self.model_url:
            server_args.extend(["--model-url", self.model_url])
        if self.model_hf_repo:
            server_args.extend(["--hf-repo", self.model_hf_repo])
        if self.model_hf_file:
            server_args.extend(["--hf-file", self.model_hf_file])
        if self.n_batch:
            server_args.extend(["--batch-size", self.n_batch])
        if self.n_ubatch:
            server_args.extend(["--ubatch-size", self.n_ubatch])
        if self.n_threads:
            server_args.extend(["--threads", self.n_threads])
        if self.n_gpu_layer:
            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
        if self.draft is not None:
            server_args.extend(["--draft", self.draft])
        if self.server_continuous_batching:
            server_args.append("--cont-batching")
        if self.server_embeddings:
            server_args.append("--embedding")
        if self.server_reranking:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
            server_args.extend(["--ctx-size", self.n_ctx])
        if self.n_slots:
            server_args.extend(["--parallel", self.n_slots])
        if self.n_server_predict:
            server_args.extend(["--n-predict", self.n_server_predict])
        if self.slot_save_path:
            server_args.extend(["--slot-save-path", self.slot_save_path])
        if self.server_api_key:
            server_args.extend(["--api-key", self.server_api_key])
        if self.n_ga:
            server_args.extend(["--grp-attn-n", self.n_ga])
        if self.n_ga_w:
            server_args.extend(["--grp-attn-w", self.n_ga_w])
        if self.debug:
            server_args.append("--verbose")
        if self.lora_file:
            server_args.extend(["--lora", self.lora_file])
        if self.disable_ctx_shift:
            server_args.extend(["--no-context-shift"])

        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"bench: starting server with: {' '.join(args)}")

        flags = 0
        if "nt" == os.name:
            flags |= subprocess.DETACHED_PROCESS
            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
            flags |= subprocess.CREATE_NO_WINDOW

        self.process = subprocess.Popen(
            [str(arg) for arg in [server_path, *server_args]],
            creationflags=flags,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env={**os.environ, "LLAMA_CACHE": "tmp"},
        )
        server_instances.add(self)

        def server_log(in_stream, out_stream):
            for line in iter(in_stream.readline, b""):
                print(line.decode("utf-8"), end="", file=out_stream)

        thread_stdout = threading.Thread(
            target=server_log, args=(self.process.stdout, sys.stdout), daemon=True
        )
        thread_stdout.start()

        thread_stderr = threading.Thread(
            target=server_log, args=(self.process.stderr, sys.stderr), daemon=True
        )
        thread_stderr.start()

        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")

        # wait for server to start
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/slots")
                if response.status_code == 200:
                    self.ready = True
                    return  # server is ready
            except Exception:
                pass
            print("Waiting for server to start...")
            time.sleep(0.5)
        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
    def stop(self) -> None:
        server_instances.remove(self)
        if self.process:
            print(f"Stopping server with pid={self.process.pid}")
            self.process.kill()
            self.process = None

    def make_request(
        self,
        method: str,
        path: str,
        data: dict | None = None,
        headers: dict | None = None,
    ) -> ServerResponse:
        url = f"http://{self.server_host}:{self.server_port}{path}"
        # keep caller-provided headers and add auth/format headers on top
        if headers is None:
            headers = {}
        if self.user_api_key:
            headers["Authorization"] = f"Bearer {self.user_api_key}"
        if self.response_format:
            headers["Accept"] = self.response_format
        if method == "GET":
            response = requests.get(url, headers=headers)
        elif method == "POST":
            response = requests.post(url, headers=headers, json=data)
        elif method == "OPTIONS":
            response = requests.options(url, headers=headers)
        else:
            raise ValueError(f"Unimplemented method: {method}")
        result = ServerResponse()
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json()
        return result

server_instances: Set[ServerProcess] = set()


def multiple_post_requests(
    server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
) -> Sequence[ServerResponse]:
    def worker(data_chunk):
        try:
            return server.make_request("POST", path, data=data_chunk, headers=headers)
        except Exception as e:
            print(f"Error occurred: {e}", file=sys.stderr)
            os._exit(1)  # terminate main thread

    threads = []
    results = []

    def thread_target(data_chunk):
        result = worker(data_chunk)
        results.append(result)

    for chunk in data:
        thread = threading.Thread(target=thread_target, args=(chunk,))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    return results
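
The multiple_post_requests helper fires one POST per element of data, each from its own thread, and collects the resulting ServerResponse objects. A usage sketch from a hypothetical test (the /completion route and its payload fields are assumptions, not part of this commit):

from utils import ServerProcess, multiple_post_requests


def test_parallel_completions():  # hypothetical test, not part of this commit
    server = ServerProcess()
    server.n_slots = 2
    server.start()
    # endpoint and payload fields are assumed for illustration only
    prompts = [{"prompt": f"Story {i}", "n_predict": 16} for i in range(4)]
    results = multiple_post_requests(server, "/completion", prompts)
    assert all(res.status_code == 200 for res in results)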