diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 39b5c8dbe..135cd6d0c 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -29,7 +29,8 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, assert type(res.body["has_new_line"]) == bool assert match_regex(re_content, res.body["content"]) if return_tokens: - assert res.body["tokens"] != [] + assert len(res.body["tokens"]) > 0 + assert all(type(tok) == int for tok in res.body["tokens"]) else: assert res.body["tokens"] == [] @@ -61,7 +62,8 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp assert data["generation_settings"]["seed"] == server.seed assert match_regex(re_content, content) else: - assert data["tokens"] != [] + assert len(res.body["tokens"]) > 0 + assert all(type(tok) == int for tok in res.body["tokens"]) content += data["content"]