more tests
parent 3249aabc0b
commit f09a9b68e1
7 changed files with 417 additions and 36 deletions
examples/server/tests/unit/test_basic.py

@@ -1,13 +1,13 @@
 import pytest
 from utils import *
 
-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()
 
 
 def test_server_start_simple():
@@ -23,3 +23,12 @@ def test_server_props():
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
     assert res.body["total_slots"] == server.n_slots
+
+
+def test_server_models():
+    global server
+    server.start()
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    assert len(res.body["data"]) == 1
+    assert res.body["data"][0]["id"] == server.model_alias
examples/server/tests/unit/test_chat_completion.py (new file, 129 lines)

@@ -0,0 +1,129 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    choice = res.body["choices"][0]
+    assert "assistant" == choice["message"]["role"]
+    assert match_regex(re_content, choice["message"]["content"])
+    if truncated:
+        assert choice["finish_reason"] == "length"
+    else:
+        assert choice["finish_reason"] == "stop"
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        choice = data["choices"][0]
+        if choice["finish_reason"] in ["stop", "length"]:
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted
+            assert "content" not in choice["delta"]
+            assert match_regex(re_content, content)
+            # FIXME: not sure why this is incorrect in stream mode
+            # if truncated:
+            #     assert choice["finish_reason"] == "length"
+            # else:
+            #     assert choice["finish_reason"] == "stop"
+        else:
+            assert choice["finish_reason"] is None
+            content += choice["delta"]["content"]
+
+
+def test_chat_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].message.content is not None
+    assert match_regex("(Suddenly)+", res.choices[0].message.content)
+
+
+@pytest.mark.parametrize("response_format,n_predicted,re_content", [
+    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
+    ({"type": "json_object"}, 10, "(\\{|John)+"),
+    ({"type": "sound"}, 0, None),
+    # invalid response format (expected to fail)
+    ({"type": "json_object", "schema": 123}, 0, None),
+    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
+    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
+])
+def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "response_format": response_format,
+    })
+    if re_content is not None:
+        assert res.status_code == 200
+        choice = res.body["choices"][0]
+        assert match_regex(re_content, choice["message"]["content"])
+    else:
+        assert res.status_code != 200
+        assert "error" in res.body
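Note: the assertions above lean on match_regex() from utils, which this diff does not touch. A minimal illustrative stand-in (the regex flags here are an assumption, not the actual helper):

    import re

    def match_regex(regex: str, text: str) -> bool:
        # True if the pattern occurs anywhere in the text.
        return re.search(regex, text, flags=re.IGNORECASE | re.DOTALL) is not None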
examples/server/tests/unit/test_completion.py

@@ -2,13 +2,13 @@ import pytest
 from openai import OpenAI
 from utils import *
 
-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -61,10 +61,62 @@ def test_completion_with_openai_library():
     res = client.completions.create(
         model="gpt-3.5-turbo-instruct",
         prompt="I believe the meaning of life is",
         n=8,
         max_tokens=8,
         seed=42,
         temperature=0.8,
     )
     print(res)
     assert res.choices[0].finish_reason == "length"
     assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_consistent_result_same_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_different_result_different_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for seed in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": seed,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] != last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_batch", [16, 32])
+@pytest.mark.parametrize("temperature", [0.0, 1.0])
+def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
+    global server
+    server.n_batch = n_batch
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": temperature,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
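Note: the three tests added above encode the sampling contract: a fixed seed must reproduce the same completion, different seeds should diverge at temperature 1.0, and batch size must not change the result. A toy sketch of the seed property, independent of the server:

    import random

    def sample_tokens(seed: int, n: int = 8) -> list[int]:
        rng = random.Random(seed)  # seeded generator, standing in for the sampler
        return [rng.randrange(32000) for _ in range(n)]  # pretend token ids

    assert sample_tokens(42) == sample_tokens(42)  # same seed: identical sequence
    assert sample_tokens(0) != sample_tokens(1)    # different seeds: almost surely different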
examples/server/tests/unit/test_embedding.py (new file, 98 lines)

@@ -0,0 +1,98 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.bert_bge_small()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.bert_bge_small()
+
+
+def test_embedding_single():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "I believe the meaning of life is",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) > 1
+
+    # make sure embedding vector is normalized
+    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-6
+
+
+def test_embedding_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "Write a joke about AI from a very long prompt which will not be truncated",
+            "This is a test",
+            "This is another test",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
+def test_embedding_openai_library_single():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
+    assert len(res.data) == 1
+    assert len(res.data[0].embedding) > 1
+
+
+def test_embedding_openai_library_multiple():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input=[
+        "I believe the meaning of life is",
+        "Write a joke about AI from a very long prompt which will not be truncated",
+        "This is a test",
+        "This is another test",
+    ])
+    assert len(res.data) == 4
+    for d in res.data:
+        assert len(d.embedding) > 1
+
+
+def test_embedding_error_prompt_too_long():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "This is a test " * 512,
+    })
+    assert res.status_code != 200
+    assert "too large" in res.body["error"]["message"]
+
+
+def test_same_prompt_give_same_result():
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 5
+    for i in range(1, len(res.body['data'])):
+        v0 = res.body['data'][0]['embedding']
+        vi = res.body['data'][i]['embedding']
+        for x, y in zip(v0, vi):
+            assert abs(x - y) < 1e-6
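Note: test_embedding_single checks that the returned vector has unit L2 norm. One practical consequence, shown here as an illustration rather than part of the suite, is that cosine similarity between two returned embeddings reduces to a plain dot product:

    def cosine_from_normalized(a: list[float], b: list[float]) -> float:
        # For unit-length vectors, the dot product equals cosine similarity.
        return sum(x * y for x, y in zip(a, b))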
examples/server/tests/unit/test_security.py (new file, 83 lines)

@@ -0,0 +1,83 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+TEST_API_KEY = "sk-this-is-the-secret-key"
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.api_key = TEST_API_KEY
+
+
+@pytest.mark.parametrize("endpoint", ["/health", "/models"])
+def test_access_public_endpoint(endpoint: str):
+    global server
+    server.start()
+    res = server.make_request("GET", endpoint)
+    assert res.status_code == 200
+    assert "error" not in res.body
+
+
+@pytest.mark.parametrize("api_key", [None, "invalid-key"])
+def test_incorrect_api_key(api_key: str):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completions", data={
+        "prompt": "I believe the meaning of life is",
+    }, headers={
+        "Authorization": f"Bearer {api_key}" if api_key else None,
+    })
+    assert res.status_code == 401
+    assert "error" in res.body
+    assert res.body["error"]["type"] == "authentication_error"
+
+
+def test_correct_api_key():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completions", data={
+        "prompt": "I believe the meaning of life is",
+    }, headers={
+        "Authorization": f"Bearer {TEST_API_KEY}",
+    })
+    assert res.status_code == 200
+    assert "error" not in res.body
+    assert "content" in res.body
+
+
+def test_openai_library_correct_api_key():
+    global server
+    server.start()
+    client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        messages=[
+            {"role": "system", "content": "You are a chatbot."},
+            {"role": "user", "content": "What is the meaning of life?"},
+        ],
+    )
+    assert len(res.choices) == 1
+
+
+@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
+    ("localhost", "Access-Control-Allow-Origin", "localhost"),
+    ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
+    ("origin", "Access-Control-Allow-Credentials", "true"),
+    ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
+    ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
+])
+def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
+    global server
+    server.start()
+    res = server.make_request("OPTIONS", "/completions", headers={
+        "Origin": origin,
+        "Access-Control-Request-Method": "POST",
+        "Access-Control-Request-Headers": "Authorization",
+    })
+    assert res.status_code == 200
+    assert cors_header in res.headers
+    assert res.headers[cors_header] == cors_header_value
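Note: the API-key tests boil down to a Bearer token in the Authorization header. The same behavior can be reproduced outside the harness with plain requests calls; the host, port, and key below are placeholders (the suite derives them from server.server_host, server.server_port, and TEST_API_KEY):

    import requests

    url = "http://127.0.0.1:8080/completions"  # placeholder endpoint
    body = {"prompt": "I believe the meaning of life is"}

    assert requests.post(url, json=body).status_code == 401  # no key: rejected
    res = requests.post(url, json=body,
                        headers={"Authorization": "Bearer sk-this-is-the-secret-key"})
    assert res.status_code == 200  # matching --api-key: accepted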
examples/server/tests/unit/test_tokenize.py

@@ -1,13 +1,13 @@
 import pytest
 from utils import *
 
-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()
 
 
 def test_tokenize_detokenize():
@@ -15,17 +15,17 @@ def test_tokenize_detokenize():
     server.start()
     # tokenize
     content = "What is the capital of France ?"
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
         "content": content
     })
-    assert resTok.status_code == 200
-    assert len(resTok.body["tokens"]) > 5
+    assert res_tok.status_code == 200
+    assert len(res_tok.body["tokens"]) > 5
     # detokenize
-    resDetok = server.make_request("POST", "/detokenize", data={
-        "tokens": resTok.body["tokens"],
+    res_detok = server.make_request("POST", "/detokenize", data={
+        "tokens": res_tok.body["tokens"],
     })
-    assert resDetok.status_code == 200
-    assert resDetok.body["content"].strip() == content
+    assert res_detok.status_code == 200
+    assert res_detok.body["content"].strip() == content
 
 
 def test_tokenize_with_bos():
@@ -34,12 +34,12 @@ def test_tokenize_with_bos():
     # tokenize
     content = "What is the capital of France ?"
     bosId = 1
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
         "content": content,
         "add_special": True,
     })
-    assert resTok.status_code == 200
-    assert resTok.body["tokens"][0] == bosId
+    assert res_tok.status_code == 200
+    assert res_tok.body["tokens"][0] == bosId
 
 
 def test_tokenize_with_pieces():
@@ -47,12 +47,12 @@ def test_tokenize_with_pieces():
     server.start()
     # tokenize
    content = "This is a test string with unicode 媽 and emoji 🤗"
-    resTok = server.make_request("POST", "/tokenize", data={
+    res_tok = server.make_request("POST", "/tokenize", data={
        "content": content,
         "with_pieces": True,
     })
-    assert resTok.status_code == 200
-    for token in resTok.body["tokens"]:
+    assert res_tok.status_code == 200
+    for token in res_tok.body["tokens"]:
         assert "id" in token
         assert token["id"] > 0
         assert "piece" in token
examples/server/tests/utils.py

@@ -58,13 +58,12 @@ class ServerProcess:
     id_slot: int | None = None
     cache_prompt: bool | None = None
     n_slots: int | None = None
-    server_api_key: str | None = None
     server_continuous_batching: bool | None = False
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
     draft: int | None = None
-    user_api_key: str | None = None
+    api_key: str | None = None
     response_format: str | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False
@@ -129,8 +128,6 @@ class ServerProcess:
             server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
-        if self.server_api_key:
-            server_args.extend(["--api-key", self.server_api_key])
         if self.n_ga:
             server_args.extend(["--grp-attn-n", self.n_ga])
         if self.n_ga_w:
@@ -141,6 +138,8 @@ class ServerProcess:
             server_args.extend(["--lora", self.lora_file])
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
+        if self.api_key:
+            server_args.extend(["--api-key", self.api_key])
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
@@ -180,7 +179,9 @@ class ServerProcess:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots")
+                response = self.make_request("GET", "/slots", headers={
+                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
+                })
                 if response.status_code == 200:
                     self.ready = True
                     return  # server is ready
@@ -205,15 +206,13 @@ class ServerProcess:
         headers: dict | None = None,
     ) -> ServerResponse:
         url = f"http://{self.server_host}:{self.server_port}{path}"
-        headers = {}
-        if self.user_api_key:
-            headers["Authorization"] = f"Bearer {self.user_api_key}"
-        if self.response_format:
-            headers["Accept"] = self.response_format
+        parse_body = False
         if method == "GET":
             response = requests.get(url, headers=headers)
+            parse_body = True
         elif method == "POST":
             response = requests.post(url, headers=headers, json=data)
+            parse_body = True
         elif method == "OPTIONS":
             response = requests.options(url, headers=headers)
         else:
@@ -221,10 +220,10 @@ class ServerProcess:
         result = ServerResponse()
         result.headers = dict(response.headers)
         result.status_code = response.status_code
-        result.body = response.json()
+        result.body = response.json() if parse_body else None
         print("Response from server", result.body)
         return result
 
     def make_stream_request(
         self,
         method: str,
@@ -233,9 +232,6 @@ class ServerProcess:
         headers: dict | None = None,
     ) -> Iterator[dict]:
         url = f"http://{self.server_host}:{self.server_port}{path}"
-        headers = {}
-        if self.user_api_key:
-            headers["Authorization"] = f"Bearer {self.user_api_key}"
        if method == "POST":
             response = requests.post(url, headers=headers, json=data, stream=True)
         else:
@@ -255,7 +251,7 @@ server_instances: Set[ServerProcess] = set()
 
 class ServerPreset:
     @staticmethod
-    def tinyllamas() -> ServerProcess:
+    def tinyllama2() -> ServerProcess:
         server = ServerProcess()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K.gguf"
@@ -266,6 +262,20 @@ class ServerPreset:
         server.n_predict = 64
         server.seed = 42
         return server
+
+    @staticmethod
+    def bert_bge_small() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 512
+        server.n_batch = 128
+        server.n_ubatch = 128
+        server.n_slots = 2
+        server.seed = 42
+        server.server_embeddings = True
+        return server
 
 
 def multiple_post_requests(
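Note: assuming the usual layout of examples/server/tests (a requirements.txt and a built server binary, neither shown in this diff), the new files should be runnable individually with pytest, e.g.:

    cd examples/server/tests
    pip install -r requirements.txt
    pytest -v unit/test_chat_completion.py

The presets fetch their models (tinyllamas/stories260K.gguf, bert-bge-small) from the ggml-org/models HF repo on first start, so the first run needs network access.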