more tests

This commit is contained in:
Xuan Son Nguyen 2024-11-20 15:00:36 +01:00
parent 3249aabc0b
commit f09a9b68e1
7 changed files with 417 additions and 36 deletions

View file

@ -1,13 +1,13 @@
import pytest
from utils import *
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
def test_server_start_simple():
@ -23,3 +23,12 @@ def test_server_props():
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["total_slots"] == server.n_slots
def test_server_models():
global server
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias

View file

@ -0,0 +1,129 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
})
assert res.status_code == 200
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
if truncated:
assert choice["finish_reason"] == "length"
else:
assert choice["finish_reason"] == "stop"
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"stream": True,
})
content = ""
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
# FIXME: not sure why this is incorrect in stream mode
# if truncated:
# assert choice["finish_reason"] == "length"
# else:
# assert choice["finish_reason"] == "stop"
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=8,
seed=42,
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "stop"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),
# invalid response format (expected to fail)
({"type": "json_object", "schema": 123}, 0, None),
({"type": "json_object", "schema": {"type": 123}}, 0, None),
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"response_format": response_format,
})
if re_content is not None:
assert res.status_code == 200
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
assert "error" in res.body

View file

@ -2,13 +2,13 @@ import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@ -61,10 +61,62 @@ def test_completion_with_openai_library():
res = client.completions.create(
model="gpt-3.5-turbo-instruct",
prompt="I believe the meaning of life is",
n=8,
max_tokens=8,
seed=42,
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "length"
assert match_regex("(going|bed)+", res.choices[0].text)
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for seed in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": seed,
"temperature": 1.0,
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0, 1.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": temperature,
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res

View file

@ -0,0 +1,98 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.bert_bge_small()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.bert_bge_small()
def test_embedding_single():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-6
def test_embedding_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_openai_library_single():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
])
assert len(res.data) == 4
for d in res.data:
assert len(d.embedding) > 1
def test_embedding_error_prompt_too_long():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]
def test_same_prompt_give_same_result():
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 5
for i in range(1, len(res.body['data'])):
v0 = res.body['data'][0]['embedding']
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < 1e-6

View file

@ -0,0 +1,83 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
TEST_API_KEY = "sk-this-is-the-secret-key"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
global server
server.start()
res = server.make_request("GET", endpoint)
assert res.status_code == 200
assert "error" not in res.body
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: str):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {api_key}" if api_key else None,
})
assert res.status_code == 401
assert "error" in res.body
assert res.body["error"]["type"] == "authentication_error"
def test_correct_api_key():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a chatbot."},
{"role": "user", "content": "What is the meaning of life?"},
],
)
assert len(res.choices) == 1
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
("localhost", "Access-Control-Allow-Origin", "localhost"),
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
("origin", "Access-Control-Allow-Credentials", "true"),
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
global server
server.start()
res = server.make_request("OPTIONS", "/completions", headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization",
})
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value

View file

@ -1,13 +1,13 @@
import pytest
from utils import *
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllamas()
server = ServerPreset.tinyllama2()
def test_tokenize_detokenize():
@ -15,17 +15,17 @@ def test_tokenize_detokenize():
server.start()
# tokenize
content = "What is the capital of France ?"
resTok = server.make_request("POST", "/tokenize", data={
res_tok = server.make_request("POST", "/tokenize", data={
"content": content
})
assert resTok.status_code == 200
assert len(resTok.body["tokens"]) > 5
assert res_tok.status_code == 200
assert len(res_tok.body["tokens"]) > 5
# detokenize
resDetok = server.make_request("POST", "/detokenize", data={
"tokens": resTok.body["tokens"],
res_detok = server.make_request("POST", "/detokenize", data={
"tokens": res_tok.body["tokens"],
})
assert resDetok.status_code == 200
assert resDetok.body["content"].strip() == content
assert res_detok.status_code == 200
assert res_detok.body["content"].strip() == content
def test_tokenize_with_bos():
@ -34,12 +34,12 @@ def test_tokenize_with_bos():
# tokenize
content = "What is the capital of France ?"
bosId = 1
resTok = server.make_request("POST", "/tokenize", data={
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"add_special": True,
})
assert resTok.status_code == 200
assert resTok.body["tokens"][0] == bosId
assert res_tok.status_code == 200
assert res_tok.body["tokens"][0] == bosId
def test_tokenize_with_pieces():
@ -47,12 +47,12 @@ def test_tokenize_with_pieces():
server.start()
# tokenize
content = "This is a test string with unicode 媽 and emoji 🤗"
resTok = server.make_request("POST", "/tokenize", data={
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"with_pieces": True,
})
assert resTok.status_code == 200
for token in resTok.body["tokens"]:
assert res_tok.status_code == 200
for token in res_tok.body["tokens"]:
assert "id" in token
assert token["id"] > 0
assert "piece" in token

View file

@ -58,13 +58,12 @@ class ServerProcess:
id_slot: int | None = None
cache_prompt: bool | None = None
n_slots: int | None = None
server_api_key: str | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
draft: int | None = None
user_api_key: str | None = None
api_key: str | None = None
response_format: str | None = None
lora_file: str | None = None
disable_ctx_shift: int | None = False
@ -129,8 +128,6 @@ class ServerProcess:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
server_args.extend(["--slot-save-path", self.slot_save_path])
if self.server_api_key:
server_args.extend(["--api-key", self.server_api_key])
if self.n_ga:
server_args.extend(["--grp-attn-n", self.n_ga])
if self.n_ga_w:
@ -141,6 +138,8 @@ class ServerProcess:
server_args.extend(["--lora", self.lora_file])
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key:
server_args.extend(["--api-key", self.api_key])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
@ -180,7 +179,9 @@ class ServerProcess:
start_time = time.time()
while time.time() - start_time < timeout_seconds:
try:
response = self.make_request("GET", "/slots")
response = self.make_request("GET", "/slots", headers={
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
})
if response.status_code == 200:
self.ready = True
return # server is ready
@ -205,15 +206,13 @@ class ServerProcess:
headers: dict | None = None,
) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}"
headers = {}
if self.user_api_key:
headers["Authorization"] = f"Bearer {self.user_api_key}"
if self.response_format:
headers["Accept"] = self.response_format
parse_body = False
if method == "GET":
response = requests.get(url, headers=headers)
parse_body = True
elif method == "POST":
response = requests.post(url, headers=headers, json=data)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers)
else:
@ -221,10 +220,10 @@ class ServerProcess:
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
result.body = response.json()
result.body = response.json() if parse_body else None
print("Response from server", result.body)
return result
def make_stream_request(
self,
method: str,
@ -233,9 +232,6 @@ class ServerProcess:
headers: dict | None = None,
) -> Iterator[dict]:
url = f"http://{self.server_host}:{self.server_port}{path}"
headers = {}
if self.user_api_key:
headers["Authorization"] = f"Bearer {self.user_api_key}"
if method == "POST":
response = requests.post(url, headers=headers, json=data, stream=True)
else:
@ -255,7 +251,7 @@ server_instances: Set[ServerProcess] = set()
class ServerPreset:
@staticmethod
def tinyllamas() -> ServerProcess:
def tinyllama2() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
@ -266,6 +262,20 @@ class ServerPreset:
server.n_predict = 64
server.seed = 42
return server
@staticmethod
def bert_bge_small() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 512
server.n_batch = 128
server.n_ubatch = 128
server.n_slots = 2
server.seed = 42
server.server_embeddings = True
return server
def multiple_post_requests(