From 055aa9e2ea16923200244eca138d5c83e983bb89 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 8 Dec 2024 22:53:00 +0100 Subject: [PATCH] update test --- examples/server/tests/unit/test_infill.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index 4b0133406..e35275709 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -13,28 +13,28 @@ def test_infill_without_input_extra(): global server server.start() res = server.make_request("POST", "/infill", data={ - "prompt": "Complete this", - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 200 - assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) + assert match_regex("(Ann|small|shiny)+", res.body["content"]) def test_infill_with_input_extra(): global server server.start() res = server.make_request("POST", "/infill", data={ - "prompt": "Complete this", "input_extra": [{ "filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n" }], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 200 - assert match_regex("(help|find|band)+", res.body["content"]) + assert match_regex("(Dad|excited|park)+", res.body["content"]) @pytest.mark.parametrize("input_extra", [ @@ -65,12 +65,12 @@ def test_with_qwen_model(): server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf" server.start(timeout_seconds=600) res = server.make_request("POST", "/infill", data={ - # "prompt": "Complete this", # FIXME: add more complicated prompt when format_infill is fixed "input_extra": [{ "filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n" }], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 200