diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 5a42c5133..76cab4ef9 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -11,17 +11,19 @@ def create_server(): @pytest.mark.parametrize( - "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja", + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False), - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, " blue and shin", 23, 8, "length", True, "This is not a chat template, it is"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] ) -def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja): +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): global server server.jinja = jinja + server.chat_template = chat_template server.start() res = server.make_request("POST", "/chat/completions", data={ "model": model, diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index d1c198063..48474a0ce 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -70,13 +70,13 @@ class ServerProcess: draft: int | None = None api_key: str | None = None lora_files: List[str] | None = None - chat_template_file: str | None = None - jinja: bool | None = None disable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None + jinja: bool | None = None chat_template: str | None = None + chat_template_file: str | None = None # session variables process: subprocess.Popen | None = None @@ -157,10 +157,6 @@ class ServerProcess: if self.lora_files: for lora_file in self.lora_files: server_args.extend(["--lora", lora_file]) - if self.chat_template_file: - server_args.extend(["--chat-template-file", self.chat_template_file]) - if self.jinja: - server_args.append("--jinja") if self.disable_ctx_shift: server_args.extend(["--no-context-shift"]) if self.api_key: @@ -171,9 +167,13 @@ class ServerProcess: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.jinja: + server_args.append("--jinja") if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) - + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) + args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}")