server : fix speculative decoding with context shift (#10641)

* server : fix speculative decoding with context shift

ggml-ci

* server : take into account speculative limits

ggml-ci

* server : add tests
This commit is contained in:
Georgi Gerganov 2024-12-04 22:38:20 +02:00 committed by GitHub
parent 59f4db1088
commit 1da7b76569
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 58 additions and 2 deletions

View file

@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
last_content = res.body["content"]
def test_slot_ctx_not_exceeded():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
def test_with_ctx_shift():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"n_predict": 64,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
assert res.body["tokens_predicted"] == 64
assert res.body["truncated"] == True
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),