added all sequential tests
This commit is contained in:
parent
eb02373f76
commit
472e128c0b
8 changed files with 332 additions and 6 deletions
|
@ -120,3 +120,5 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
|
|||
if last_res is not None:
|
||||
assert res.body["content"] == last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
# TODO: add completion with tokens as input, mixed token+string input
|
||||
|
|
67
examples/server/tests/unit/test_ctx_shift.py
Normal file
67
examples/server/tests/unit/test_ctx_shift.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
LONG_TEXT = """
|
||||
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||
""".strip()
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.n_ctx = 256
|
||||
server.n_slots = 2
|
||||
|
||||
|
||||
def test_ctx_shift_enabled():
|
||||
# the prompt is 301 tokens
|
||||
# the slot context is 256/2 = 128 tokens
|
||||
# the prompt is truncated to keep the last 109 tokens
|
||||
# 64 tokens are generated thanks to shifting the context when it gets full
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": LONG_TEXT,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == 109
|
||||
assert res.body["timings"]["predicted_n"] == 64
|
||||
assert res.body["truncated"] is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
|
||||
(64, 64, False),
|
||||
(-1, 120, True),
|
||||
])
|
||||
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
|
||||
global server
|
||||
server.disable_ctx_shift = True
|
||||
server.n_predict = -1
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": "Hi how are you",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["predicted_n"] == n_token_output
|
||||
assert res.body["truncated"] == truncated
|
||||
|
||||
|
||||
def test_ctx_shift_disabled_long_prompt():
|
||||
global server
|
||||
server.disable_ctx_shift = True
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": LONG_TEXT,
|
||||
})
|
||||
assert res.status_code != 200
|
||||
assert "error" in res.body
|
||||
assert "exceeds the available context size" in res.body["error"]["message"]
|
|
@ -4,6 +4,7 @@ from utils import *
|
|||
|
||||
server = ServerPreset.bert_bge_small()
|
||||
|
||||
EPSILON = 1e-3
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
|
@ -23,7 +24,7 @@ def test_embedding_single():
|
|||
assert len(res.body['data'][0]['embedding']) > 1
|
||||
|
||||
# make sure embedding vector is normalized
|
||||
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-5
|
||||
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
|
||||
|
||||
|
||||
def test_embedding_multiple():
|
||||
|
@ -95,4 +96,4 @@ def test_same_prompt_give_same_result():
|
|||
v0 = res.body['data'][0]['embedding']
|
||||
vi = res.body['data'][i]['embedding']
|
||||
for x, y in zip(v0, vi):
|
||||
assert abs(x - y) < 1e-5
|
||||
assert abs(x - y) < EPSILON
|
||||
|
|
35
examples/server/tests/unit/test_infill.py
Normal file
35
examples/server/tests/unit/test_infill.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama_infill()
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama_infill()
|
||||
|
||||
def test_infill_without_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"prompt": "Complete this",
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
||||
|
||||
def test_infill_with_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"prompt": "Complete this",
|
||||
"input_extra": [{
|
||||
"filename": "llama.h",
|
||||
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
||||
}],
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
42
examples/server/tests/unit/test_lora.py
Normal file
42
examples/server/tests/unit/test_lora.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import pytest
|
||||
import os
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.stories15m_moe()
|
||||
# download lora file if needed
|
||||
file_name = LORA_FILE_URL.split('/').pop()
|
||||
lora_file = f'../../../{file_name}'
|
||||
if not os.path.exists(lora_file):
|
||||
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
|
||||
with open(lora_file, 'wb') as f:
|
||||
f.write(requests.get(LORA_FILE_URL).content)
|
||||
print(f"Done downloading lora file")
|
||||
server.lora_files = [lora_file]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scale,re_content", [
|
||||
# without applying lora, the model should behave like a bedtime story generator
|
||||
(0.0, "(little|girl|three|years|old)+"),
|
||||
# with lora, the model should behave like a Shakespearean text generator
|
||||
(1.0, "(eye|love|glass|sun)+"),
|
||||
])
|
||||
def test_lora(scale: float, re_content: str):
|
||||
global server
|
||||
server.start()
|
||||
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
|
||||
{"id": 0, "scale": scale}
|
||||
])
|
||||
assert res_lora_control.status_code == 200
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Look in thy glass",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
38
examples/server/tests/unit/test_rerank.py
Normal file
38
examples/server/tests/unit/test_rerank.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.jina_reranker_tiny()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.jina_reranker_tiny()
|
||||
|
||||
|
||||
def test_rerank():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"documents": [
|
||||
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
||||
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
||||
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
||||
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == 4
|
||||
|
||||
most_relevant = res.body["results"][0]
|
||||
least_relevant = res.body["results"][0]
|
||||
for doc in res.body["results"]:
|
||||
if doc["relevance_score"] > most_relevant["relevance_score"]:
|
||||
most_relevant = doc
|
||||
if doc["relevance_score"] < least_relevant["relevance_score"]:
|
||||
least_relevant = doc
|
||||
|
||||
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
||||
assert most_relevant["index"] == 2
|
||||
assert least_relevant["index"] == 3
|
97
examples/server/tests/unit/test_slot_save.py
Normal file
97
examples/server/tests/unit/test_slot_save.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.slot_save_path = "./tmp"
|
||||
|
||||
|
||||
def test_slot_save_restore():
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# First prompt in slot 1 should be fully processed
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Lily|cake)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
|
||||
# Save state of slot 1
|
||||
res = server.make_request("POST", "/slots/1?action=save", data={
|
||||
"filename": "slot1.bin",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["n_saved"] == 84
|
||||
|
||||
# Since we have cache, this should only process the last tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
|
||||
|
||||
# Loading the saved cache into slot 0
|
||||
res = server.make_request("POST", "/slots/0?action=restore", data={
|
||||
"filename": "slot1.bin",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["n_restored"] == 84
|
||||
|
||||
# Since we have cache, slot 0 should only process the last tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 0,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
|
||||
|
||||
# For verification that slot 1 was not corrupted during slot 0 load, same thing should work
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 1
|
||||
|
||||
|
||||
def test_slot_erase():
|
||||
global server
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Lily|cake)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
|
||||
# erase slot 1
|
||||
res = server.make_request("POST", "/slots/1?action=erase")
|
||||
assert res.status_code == 200
|
||||
|
||||
# re-run the same prompt, it should process all tokens again
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Lily|cake)+", res.body["content"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
|
@ -17,6 +17,7 @@ from typing import (
|
|||
ContextManager,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
Set,
|
||||
|
@ -65,7 +66,7 @@ class ServerProcess:
|
|||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
response_format: str | None = None
|
||||
lora_file: str | None = None
|
||||
lora_files: List[str] | None = None
|
||||
disable_ctx_shift: int | None = False
|
||||
|
||||
# session variables
|
||||
|
@ -134,8 +135,9 @@ class ServerProcess:
|
|||
server_args.extend(["--grp-attn-w", self.n_ga_w])
|
||||
if self.debug:
|
||||
server_args.append("--verbose")
|
||||
if self.lora_file:
|
||||
server_args.extend(["--lora", self.lora_file])
|
||||
if self.lora_files:
|
||||
for lora_file in self.lora_files:
|
||||
server_args.extend(["--lora", lora_file])
|
||||
if self.disable_ctx_shift:
|
||||
server_args.extend(["--no-context-shift"])
|
||||
if self.api_key:
|
||||
|
@ -202,7 +204,7 @@ class ServerProcess:
|
|||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
data: dict | Any | None = None,
|
||||
headers: dict | None = None,
|
||||
) -> ServerResponse:
|
||||
url = f"http://{self.server_host}:{self.server_port}{path}"
|
||||
|
@ -277,6 +279,48 @@ class ServerPreset:
|
|||
server.server_embeddings = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinyllama_infill() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
|
||||
server.model_alias = "tinyllama-infill"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
server.n_slots = 1
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def stories15m_moe() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/stories15M_MOE"
|
||||
server.model_hf_file = "stories15M_MOE-F16.gguf"
|
||||
server.model_alias = "stories15m-moe"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
server.n_slots = 1
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def jina_reranker_tiny() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
|
||||
server.model_alias = "jina-reranker"
|
||||
server.model_file = "./tmp/jina-reranker-v1-tiny-en.gguf"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 512
|
||||
server.n_slots = 1
|
||||
server.seed = 42
|
||||
server.server_reranking = True
|
||||
return server
|
||||
|
||||
|
||||
def multiple_post_requests(
|
||||
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue