add invalid cases
This commit is contained in:
parent
8aaf69a3ee
commit
879c5ebd25
3 changed files with 58 additions and 0 deletions
|
@ -127,3 +127,22 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
|
|||
assert res.status_code != 200
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("messages", [
|
||||
None,
|
||||
"string",
|
||||
[123],
|
||||
[{}],
|
||||
[{"role": 123}],
|
||||
[{"role": "system", "content": 123}],
|
||||
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||
[{"role": "system", "content": "test"}, {}],
|
||||
])
|
||||
def test_invalid_chat_completion_req(messages):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"messages": messages,
|
||||
})
|
||||
assert res.status_code == 400 or res.status_code == 500
|
||||
assert "error" in res.body
|
||||
|
|
|
@ -8,6 +8,7 @@ def create_server():
|
|||
global server
|
||||
server = ServerPreset.tinyllama_infill()
|
||||
|
||||
|
||||
def test_infill_without_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
|
@ -19,6 +20,7 @@ def test_infill_without_input_extra():
|
|||
assert res.status_code == 200
|
||||
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
||||
|
||||
|
||||
def test_infill_with_input_extra():
|
||||
global server
|
||||
server.start()
|
||||
|
@ -33,3 +35,23 @@ def test_infill_with_input_extra():
|
|||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_extra", [
|
||||
{},
|
||||
{"filename": "ok"},
|
||||
{"filename": 123},
|
||||
{"filename": 123, "text": "abc"},
|
||||
{"filename": 123, "text": 456},
|
||||
])
|
||||
def test_invalid_input_extra_req(input_extra):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/infill", data={
|
||||
"prompt": "Complete this",
|
||||
"input_extra": [input_extra],
|
||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
|
|
@ -36,3 +36,20 @@ def test_rerank():
|
|||
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
|
||||
assert most_relevant["index"] == 2
|
||||
assert least_relevant["index"] == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("documents", [
|
||||
[],
|
||||
None,
|
||||
123,
|
||||
[1, 2, 3],
|
||||
])
|
||||
def test_invalid_rerank_req(documents):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/rerank", data={
|
||||
"query": "Machine learning is",
|
||||
"documents": documents,
|
||||
})
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue