add invalid cases
This commit is contained in:
parent
8aaf69a3ee
commit
879c5ebd25
3 changed files with 58 additions and 0 deletions
|
@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
|
||||||
assert res.status_code != 200
|
assert res.status_code != 200
|
||||||
assert "error" in res.body
|
assert "error" in res.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("messages", [
|
||||||
|
None,
|
||||||
|
"string",
|
||||||
|
[123],
|
||||||
|
[{}],
|
||||||
|
[{"role": 123}],
|
||||||
|
[{"role": "system", "content": 123}],
|
||||||
|
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||||
|
[{"role": "system", "content": "test"}, {}],
|
||||||
|
])
|
||||||
|
def test_invalid_chat_completion_req(messages):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
"messages": messages,
|
||||||
|
})
|
||||||
|
assert res.status_code == 400 or res.status_code == 500
|
||||||
|
assert "error" in res.body
|
||||||
|
|
|
@ -8,6 +8,7 @@ def create_server():
|
||||||
global server
|
global server
|
||||||
server = ServerPreset.tinyllama_infill()
|
server = ServerPreset.tinyllama_infill()
|
||||||
|
|
||||||
|
|
||||||
def test_infill_without_input_extra():
|
def test_infill_without_input_extra():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -19,6 +20,7 @@ def test_infill_without_input_extra():
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
def test_infill_with_input_extra():
|
def test_infill_with_input_extra():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -33,3 +35,23 @@ def test_infill_with_input_extra():
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input_extra", [
|
||||||
|
{},
|
||||||
|
{"filename": "ok"},
|
||||||
|
{"filename": 123},
|
||||||
|
{"filename": 123, "text": "abc"},
|
||||||
|
{"filename": 123, "text": 456},
|
||||||
|
])
|
||||||
|
def test_invalid_input_extra_req(input_extra):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/infill", data={
|
||||||
|
"prompt": "Complete this",
|
||||||
|
"input_extra": [input_extra],
|
||||||
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
||||||
|
"input_suffix": "}\n",
|
||||||
|
})
|
||||||
|
assert res.status_code == 400
|
||||||
|
assert "error" in res.body
|
||||||
|
|
|
@ -36,3 +36,20 @@ def test_rerank():
|
||||||
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
||||||
assert most_relevant["index"] == 2
|
assert most_relevant["index"] == 2
|
||||||
assert least_relevant["index"] == 3
|
assert least_relevant["index"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("documents", [
|
||||||
|
[],
|
||||||
|
None,
|
||||||
|
123,
|
||||||
|
[1, 2, 3],
|
||||||
|
])
|
||||||
|
def test_invalid_rerank_req(documents):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/rerank", data={
|
||||||
|
"query": "Machine learning is",
|
||||||
|
"documents": documents,
|
||||||
|
})
|
||||||
|
assert res.status_code == 400
|
||||||
|
assert "error" in res.body
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue