server : replace behave with pytest (#10416)
* server : replace behave with pytest * fix test on windows * misc * add more tests * more tests * styling * log less, fix embd test * added all sequential tests * fix coding style * fix save slot test * add parallel completion test * fix parallel test * remove feature files * update test docs * no cache_prompt for some tests * add test_cache_vs_nocache_prompt
This commit is contained in:
parent
0bbd2262a3
commit
45abe0f74e
34 changed files with 1317 additions and 2497 deletions
67
examples/server/tests/unit/test_ctx_shift.py
Normal file
67
examples/server/tests/unit/test_ctx_shift.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
|
||||
# Module-level handle to the server under test. The autouse create_server
# fixture rebinds this name, so the value assigned here is only a placeholder.
server = ServerPreset.tinyllama2()


# Prompt used by the long-prompt tests. The tests below rely on its exact
# token count, so the text must not be edited. The leading and trailing
# newlines of the triple-quoted literal are removed by .strip().
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
|
||||
|
||||
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh preset before every test.

    The tests in this module mutate the shared object (``disable_ctx_shift``,
    ``n_predict``). With a module-scoped fixture those mutations leaked from
    one test into the next, making the results depend on execution order.
    Using the default function scope gives each test its own clean
    configuration. The fixture only builds and configures the preset; each
    test calls ``server.start()`` itself.
    """
    global server
    server = ServerPreset.tinyllama2()
    # Total context of 256 tokens split across 2 slots (128 tokens per slot,
    # as the comments in test_ctx_shift_enabled spell out).
    server.n_ctx = 256
    server.n_slots = 2
|
||||
|
||||
|
||||
def test_ctx_shift_enabled():
    """Check that a too-long prompt is truncated and generation still completes.

    Expected numbers for this configuration:
    - the prompt is 301 tokens
    - the slot context is 256/2 = 128 tokens
    - the prompt is truncated to keep the last 109 tokens
    - 64 tokens are generated thanks to shifting the context when it gets full
    """
    # ``server`` is only read here, never rebound, so no ``global`` is needed.
    server.start()
    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    response = server.make_request("POST", "/completion", data=payload)
    assert response.status_code == 200
    timings = response.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert response.body["truncated"] is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    (64, 64, False),
    (-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Check generation length for a short prompt when context shift is off.

    Each parameter row gives the requested ``n_predict``, the number of
    tokens the response must report as predicted, and the expected value of
    the ``truncated`` flag in the response body.
    """
    # ``server`` is only mutated through attributes, never rebound, so no
    # ``global`` declaration is needed.
    server.disable_ctx_shift = True
    server.n_predict = -1
    server.start()
    payload = {
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    }
    response = server.make_request("POST", "/completion", data=payload)
    assert response.status_code == 200
    body = response.body
    assert body["timings"]["predicted_n"] == n_token_output
    assert body["truncated"] == truncated
|
||||
|
||||
|
||||
def test_ctx_shift_disabled_long_prompt():
    """Check that a too-long prompt is rejected when context shift is off.

    The request must not return HTTP 200, and the error message in the body
    must mention that the prompt exceeds the available context size.
    """
    # ``server`` is only mutated through attributes, never rebound, so no
    # ``global`` declaration is needed.
    server.disable_ctx_shift = True
    server.start()
    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    response = server.make_request("POST", "/completion", data=payload)
    assert response.status_code != 200
    body = response.body
    assert "error" in body
    assert "exceeds the available context size" in body["error"]["message"]
|
Loading…
Add table
Add a link
Reference in a new issue