more tests

parent 3249aabc0b
commit f09a9b68e1

7 changed files with 417 additions and 36 deletions
@@ -1,13 +1,13 @@
 import pytest
 from utils import *

-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()


 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()


 def test_server_start_simple():
@@ -23,3 +23,12 @@ def test_server_props():
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
     assert res.body["total_slots"] == server.n_slots
+
+
+def test_server_models():
+    global server
+    server.start()
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    assert len(res.body["data"]) == 1
+    assert res.body["data"][0]["id"] == server.model_alias
examples/server/tests/unit/test_chat_completion.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    choice = res.body["choices"][0]
+    assert "assistant" == choice["message"]["role"]
+    assert match_regex(re_content, choice["message"]["content"])
+    if truncated:
+        assert choice["finish_reason"] == "length"
+    else:
+        assert choice["finish_reason"] == "stop"
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        choice = data["choices"][0]
+        if choice["finish_reason"] in ["stop", "length"]:
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted
+            assert "content" not in choice["delta"]
+            assert match_regex(re_content, content)
+            # FIXME: not sure why this is incorrect in stream mode
+            # if truncated:
+            #     assert choice["finish_reason"] == "length"
+            # else:
+            #     assert choice["finish_reason"] == "stop"
+        else:
+            assert choice["finish_reason"] is None
+            content += choice["delta"]["content"]
+
+
+def test_chat_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].message.content is not None
+    assert match_regex("(Suddenly)+", res.choices[0].message.content)
+
+
+@pytest.mark.parametrize("response_format,n_predicted,re_content", [
+    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
+    ({"type": "json_object"}, 10, "(\\{|John)+"),
+    ({"type": "sound"}, 0, None),
+    # invalid response format (expected to fail)
+    ({"type": "json_object", "schema": 123}, 0, None),
+    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
+    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
+])
+def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "response_format": response_format,
+    })
+    if re_content is not None:
+        assert res.status_code == 200
+        choice = res.body["choices"][0]
+        assert match_regex(re_content, choice["message"]["content"])
+    else:
+        assert res.status_code != 200
+        assert "error" in res.body
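Note on test_chat_completion_stream above: it consumes already-parsed dicts from make_stream_request, a utils helper whose body is not part of this diff. Presumably that helper splits the HTTP response into server-sent events and JSON-decodes each "data: {...}" line; a rough sketch under that assumption (hypothetical function, not the actual helper):

    import json
    import requests

    def iter_sse_json(response: requests.Response):
        # each streamed chunk is an SSE line of the form: data: {"choices": [...]}
        for raw in response.iter_lines():
            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            if line.startswith("data: ") and line != "data: [DONE]":
                yield json.loads(line[len("data: "):])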
@@ -2,13 +2,13 @@ import pytest
 from openai import OpenAI
 from utils import *

-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()


 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()


 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -61,10 +61,62 @@ def test_completion_with_openai_library():
     res = client.completions.create(
         model="gpt-3.5-turbo-instruct",
         prompt="I believe the meaning of life is",
-        n=8,
+        max_tokens=8,
         seed=42,
         temperature=0.8,
     )
     print(res)
     assert res.choices[0].finish_reason == "length"
     assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_consistent_result_same_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_different_result_different_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for seed in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": seed,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] != last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_batch", [16, 32])
+@pytest.mark.parametrize("temperature", [0.0, 1.0])
+def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
+    global server
+    server.n_batch = n_batch
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": temperature,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
examples/server/tests/unit/test_embedding.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.bert_bge_small()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.bert_bge_small()
+
+
+def test_embedding_single():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "I believe the meaning of life is",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) > 1
+
+    # make sure embedding vector is normalized
+    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-6
+
+
+def test_embedding_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "Write a joke about AI from a very long prompt which will not be truncated",
+            "This is a test",
+            "This is another test",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
+def test_embedding_openai_library_single():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
+    assert len(res.data) == 1
+    assert len(res.data[0].embedding) > 1
+
+
+def test_embedding_openai_library_multiple():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input=[
+        "I believe the meaning of life is",
+        "Write a joke about AI from a very long prompt which will not be truncated",
+        "This is a test",
+        "This is another test",
+    ])
+    assert len(res.data) == 4
+    for d in res.data:
+        assert len(d.embedding) > 1
+
+
+def test_embedding_error_prompt_too_long():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "This is a test " * 512,
+    })
+    assert res.status_code != 200
+    assert "too large" in res.body["error"]["message"]
+
+
+def test_same_prompt_give_same_result():
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 5
+    for i in range(1, len(res.body['data'])):
+        v0 = res.body['data'][0]['embedding']
+        vi = res.body['data'][i]['embedding']
+        for x, y in zip(v0, vi):
+            assert abs(x - y) < 1e-6
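Aside on the normalization assertion in test_embedding_single above: summing the squared components and comparing the result against 1 checks that the returned embedding has unit L2 norm. An equivalent formulation, assuming numpy were available (it is not imported by this test suite):

    import numpy as np

    def is_unit_norm(embedding: list[float], eps: float = 1e-6) -> bool:
        # ||v||_2 == 1 is the same condition as sum(x**2) == 1, up to the tolerance
        return abs(float(np.linalg.norm(np.asarray(embedding))) - 1.0) < eps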
examples/server/tests/unit/test_security.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+TEST_API_KEY = "sk-this-is-the-secret-key"
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.api_key = TEST_API_KEY
+
+
+@pytest.mark.parametrize("endpoint", ["/health", "/models"])
+def test_access_public_endpoint(endpoint: str):
+    global server
+    server.start()
+    res = server.make_request("GET", endpoint)
+    assert res.status_code == 200
+    assert "error" not in res.body
+
+
+@pytest.mark.parametrize("api_key", [None, "invalid-key"])
+def test_incorrect_api_key(api_key: str):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completions", data={
+        "prompt": "I believe the meaning of life is",
+    }, headers={
+        "Authorization": f"Bearer {api_key}" if api_key else None,
+    })
+    assert res.status_code == 401
+    assert "error" in res.body
+    assert res.body["error"]["type"] == "authentication_error"
+
+
+def test_correct_api_key():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completions", data={
+        "prompt": "I believe the meaning of life is",
+    }, headers={
+        "Authorization": f"Bearer {TEST_API_KEY}",
+    })
+    assert res.status_code == 200
+    assert "error" not in res.body
+    assert "content" in res.body
+
+
+def test_openai_library_correct_api_key():
+    global server
+    server.start()
+    client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=[
+            {"role": "system", "content": "You are a chatbot."},
+            {"role": "user", "content": "What is the meaning of life?"},
+        ],
+    )
+    assert len(res.choices) == 1
+
+
+@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
+    ("localhost", "Access-Control-Allow-Origin", "localhost"),
+    ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
+    ("origin", "Access-Control-Allow-Credentials", "true"),
+    ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
+    ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
+])
+def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
+    global server
+    server.start()
+    res = server.make_request("OPTIONS", "/completions", headers={
+        "Origin": origin,
+        "Access-Control-Request-Method": "POST",
+        "Access-Control-Request-Headers": "Authorization",
+    })
+    assert res.status_code == 200
+    assert cors_header in res.headers
+    assert res.headers[cors_header] == cors_header_value
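One subtlety in test_incorrect_api_key above: the parametrized None case sets the Authorization header value to None, which relies on the requests library dropping None-valued headers during header merging, so that case sends no Authorization header at all. A small illustration of that behavior (hypothetical URL; nothing is sent over the network):

    import requests

    s = requests.Session()
    s.headers["Authorization"] = "Bearer default"
    req = s.prepare_request(requests.Request(
        "GET", "http://localhost:8080/health",
        headers={"Authorization": None},  # a None value removes the header
    ))
    assert "Authorization" not in req.headers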
@@ -1,13 +1,13 @@
 import pytest
 from utils import *

-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()


 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()


 def test_tokenize_detokenize():
@@ -15,17 +15,17 @@ def test_tokenize_detokenize():
     server.start()
     # tokenize
     content = "What is the capital of France ?"
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
         "content": content
     })
-    assert resTok.status_code == 200
-    assert len(resTok.body["tokens"]) > 5
+    assert res_tok.status_code == 200
+    assert len(res_tok.body["tokens"]) > 5
     # detokenize
-    resDetok = server.make_request("POST", "/detokenize", data={
-        "tokens": resTok.body["tokens"],
+    res_detok = server.make_request("POST", "/detokenize", data={
+        "tokens": res_tok.body["tokens"],
     })
-    assert resDetok.status_code == 200
-    assert resDetok.body["content"].strip() == content
+    assert res_detok.status_code == 200
+    assert res_detok.body["content"].strip() == content


 def test_tokenize_with_bos():
@@ -34,12 +34,12 @@ def test_tokenize_with_bos():
     # tokenize
     content = "What is the capital of France ?"
     bosId = 1
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
         "content": content,
         "add_special": True,
     })
-    assert resTok.status_code == 200
-    assert resTok.body["tokens"][0] == bosId
+    assert res_tok.status_code == 200
+    assert res_tok.body["tokens"][0] == bosId


 def test_tokenize_with_pieces():
@@ -47,12 +47,12 @@ def test_tokenize_with_pieces():
     server.start()
     # tokenize
     content = "This is a test string with unicode 媽 and emoji 🤗"
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
         "content": content,
         "with_pieces": True,
     })
-    assert resTok.status_code == 200
-    for token in resTok.body["tokens"]:
+    assert res_tok.status_code == 200
+    for token in res_tok.body["tokens"]:
         assert "id" in token
         assert token["id"] > 0
         assert "piece" in token
@@ -58,13 +58,12 @@ class ServerProcess:
     id_slot: int | None = None
     cache_prompt: bool | None = None
     n_slots: int | None = None
-    server_api_key: str | None = None
     server_continuous_batching: bool | None = False
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     draft: int | None = None
-    user_api_key: str | None = None
+    api_key: str | None = None
     response_format: str | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False
@@ -129,8 +128,6 @@ class ServerProcess:
             server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
-        if self.server_api_key:
-            server_args.extend(["--api-key", self.server_api_key])
         if self.n_ga:
             server_args.extend(["--grp-attn-n", self.n_ga])
         if self.n_ga_w:
@@ -141,6 +138,8 @@ class ServerProcess:
             server_args.extend(["--lora", self.lora_file])
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
+        if self.api_key:
+            server_args.extend(["--api-key", self.api_key])

         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
@@ -180,7 +179,9 @@ class ServerProcess:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots")
+                response = self.make_request("GET", "/slots", headers={
+                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
+                })
                 if response.status_code == 200:
                     self.ready = True
                     return  # server is ready
@@ -205,15 +206,13 @@ class ServerProcess:
         headers: dict | None = None,
     ) -> ServerResponse:
         url = f"http://{self.server_host}:{self.server_port}{path}"
-        headers = {}
-        if self.user_api_key:
-            headers["Authorization"] = f"Bearer {self.user_api_key}"
-        if self.response_format:
-            headers["Accept"] = self.response_format
+        parse_body = False
         if method == "GET":
             response = requests.get(url, headers=headers)
+            parse_body = True
         elif method == "POST":
             response = requests.post(url, headers=headers, json=data)
+            parse_body = True
         elif method == "OPTIONS":
             response = requests.options(url, headers=headers)
         else:
@@ -221,10 +220,10 @@ class ServerProcess:
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json()
+        result.body = response.json() if parse_body else None
         print("Response from server", result.body)
         return result

     def make_stream_request(
         self,
         method: str,
@@ -233,9 +232,6 @@ class ServerProcess:
         headers: dict | None = None,
     ) -> Iterator[dict]:
         url = f"http://{self.server_host}:{self.server_port}{path}"
-        headers = {}
-        if self.user_api_key:
-            headers["Authorization"] = f"Bearer {self.user_api_key}"
         if method == "POST":
             response = requests.post(url, headers=headers, json=data, stream=True)
         else:
@@ -255,7 +251,7 @@ server_instances: Set[ServerProcess] = set()

 class ServerPreset:
     @staticmethod
-    def tinyllamas() -> ServerProcess:
+    def tinyllama2() -> ServerProcess:
         server = ServerProcess()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K.gguf"
@@ -266,6 +262,20 @@ class ServerPreset:
         server.n_predict = 64
         server.seed = 42
         return server

+    @staticmethod
+    def bert_bge_small() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 512
+        server.n_batch = 128
+        server.n_ubatch = 128
+        server.n_slots = 2
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+

 def multiple_post_requests(
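Finally, the assertions throughout these tests lean on a match_regex helper from utils that this diff does not touch. A plausible minimal implementation, offered only as an assumption about its behavior (match anywhere in the generated text, tolerant of case and newlines):

    import re

    def match_regex(pattern: str, text: str) -> bool:
        # True if the pattern occurs anywhere in the text
        flags = re.IGNORECASE | re.MULTILINE | re.DOTALL
        return re.search(pattern, text, flags=flags) is not None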