From 93a5245b0e21f47cc0c0777181cb44ec57ae8e39 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 10 Dec 2024 01:11:08 +0000 Subject: [PATCH] tool-calls: migrate tests to pytest --- common/tool-call.cpp | 6 +- .../server/tests/features/tool_call.feature | 163 ------------------ examples/server/tests/pytest.ini | 4 + examples/server/tests/tests.sh | 2 +- .../server/tests/unit/test_chat_completion.py | 156 +++++++++++++++++ examples/server/tests/utils.py | 6 + .../meta-llama-Llama-3.3-70B-Instruct.jinja | 109 ++++++++++++ tests/test-tool-call.cpp | 1 + 8 files changed, 282 insertions(+), 165 deletions(-) delete mode 100644 examples/server/tests/features/tool_call.feature create mode 100644 examples/server/tests/pytest.ini create mode 100644 tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja diff --git a/common/tool-call.cpp b/common/tool-call.cpp index b209c9145..3523b28b4 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -383,7 +383,11 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages json messages_with_system = messages; if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { - messages_with_system.at(0).at("content") += ("\n" + system_prompt); + std::string existing_system = messages_with_system.at(0).at("content"); + messages_with_system[0] = json { + {"role", "system"}, + {"content", existing_system + "\n" + system_prompt}, + }; } else { messages_with_system.insert(messages_with_system.begin(), json { {"role", "system"}, diff --git a/examples/server/tests/features/tool_call.feature b/examples/server/tests/features/tool_call.feature deleted file mode 100644 index a0d99e452..000000000 --- a/examples/server/tests/features/tool_call.feature +++ /dev/null @@ -1,163 +0,0 @@ -@llama.cpp -@server -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And BOS token is 1 - And 42 as server seed - And greedy sampling - And 8192 KV cache size - And 32 as batch size - And 1 slots - And prometheus compatible metrics exposed - And jinja templates are enabled - - - Scenario Outline: Template + tinystories model w/ required tool_choice yields tool call - Given a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a test chat template file named - And the server is starting - And the server is healthy - And a model test - And max tokens to predict - And a user prompt say hello world with python - And a tool choice required - And tool - And parallel tool calls is - And an OAI compatible chat completions request with no api error - Then tool is called with arguments - - Examples: Prompts - | template_name | n_predict | tool_name | tool_arguments | parallel_tool_calls | - | meetkai-functionary-medium-v3.1 | 32 | test | {} | disabled | - | meetkai-functionary-medium-v3.1 | 32 | python | {"code": ". 
She was so excited to go to the park and s"} | disabled | - | meetkai-functionary-medium-v3.2 | 32 | test | {} | disabled | - | meetkai-functionary-medium-v3.2 | 32 | python | {"code": "Yes,"} | disabled | - | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use | 128 | test | {} | disabled | - | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use | 128 | python | {"code": "Yes,"} | disabled | - | NousResearch-Hermes-3-Llama-3.1-8B-tool_use | 128 | test | {} | disabled | - | NousResearch-Hermes-3-Llama-3.1-8B-tool_use | 128 | python | {"code": "Yes,"} | disabled | - | meta-llama-Meta-Llama-3.1-8B-Instruct | 128 | test | {} | disabled | - | meta-llama-Meta-Llama-3.1-8B-Instruct | 128 | python | {"code": "It's a shark."} | disabled | - | meta-llama-Llama-3.2-3B-Instruct | 128 | test | {} | disabled | - | meta-llama-Llama-3.2-3B-Instruct | 128 | python | {"code": "It's a shark."} | disabled | - | mistralai-Mistral-Nemo-Instruct-2407 | 128 | test | {} | disabled | - | mistralai-Mistral-Nemo-Instruct-2407 | 128 | python | {"code": "It's a small cost."} | disabled | - - - Scenario Outline: Template + tinystories model yields no tool call - Given a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a test chat template file named - And the server is starting - And the server is healthy - And a model test - And max tokens to predict - And a user prompt say hello world with python - And tools [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] - And an OAI compatible chat completions request with no api error - Then no tool is called - - Examples: Prompts - | template_name | n_predict | - | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | - | meetkai-functionary-medium-v3.1 | 128 | - | meetkai-functionary-medium-v3.2 | 128 | - - - Scenario: Tool call template + tinystories and no tool won't call any tool - Given a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a test chat template file named meta-llama-Meta-Llama-3.1-8B-Instruct - And the server is starting - And the server is healthy - And a model test - And 16 max tokens to predict - And a user prompt say hello world with python - And tools [] - And an OAI compatible chat completions request with no api error - Then no tool is called - - - @slow - Scenario Outline: Python hello world w/ + tool yields python call - Given a model file from HF repo - And a test chat template file named - And no warmup - And the server is starting - And the server is healthy - And a model test - And 256 max tokens to predict - And a user prompt say hello world with python - And tool - And parallel tool calls is disabled - And an OAI compatible chat completions request with no api error - Then tool python is called with arguments - - Examples: Prompts - | tool | tool_arguments | hf_repo | hf_file | template_override | - | python | {"code": "print('Hello, world!')"} | bartowski/gemma-2-2b-it-GGUF | gemma-2-2b-it-Q4_K_M.gguf | | - | python | {"code": "print('Hello, World!')"} | bartowski/Mistral-Nemo-Instruct-2407-GGUF | Mistral-Nemo-Instruct-2407-Q4_K_M.gguf | | - | python | {"code": "print(\"Hello World\")"} | bartowski/Qwen2.5-7B-Instruct-GGUF | Qwen2.5-7B-Instruct-Q4_K_M.gguf | | - | python | {"code": "print('Hello, World!')"} | bartowski/Phi-3.5-mini-instruct-GGUF | Phi-3.5-mini-instruct-Q4_K_M.gguf | | - | python | {"code": "print('Hello, world!')"} | NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF | Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf | 
NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use | - | python | {"code": "print('hello world')"} | NousResearch/Hermes-3-Llama-3.1-8B-GGUF | Hermes-3-Llama-3.1-8B.Q4_K_M.gguf | NousResearch-Hermes-3-Llama-3.1-8B-tool_use | - | python | {"code": "print('Hello, World!'}"} | bartowski/Llama-3.2-1B-Instruct-GGUF | Llama-3.2-1B-Instruct-Q4_K_M.gguf | meta-llama-Llama-3.2-3B-Instruct | - | python | {"code": "print("} | bartowski/Llama-3.2-3B-Instruct-GGUF | Llama-3.2-3B-Instruct-Q4_K_M.gguf | meta-llama-Llama-3.2-3B-Instruct | - | python | {"code": "print("} | lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF | Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf | | - | python | {"code": "print('Hello, World!')"} | bartowski/functionary-small-v3.2-GGUF | functionary-small-v3.2-Q8_0.gguf | meetkai-functionary-medium-v3.2 | - | code_interpreter | {"code": "print('Hello, world!')"} | bartowski/gemma-2-2b-it-GGUF | gemma-2-2b-it-Q4_K_M.gguf | | - | code_interpreter | {"code": "print('Hello, World!')"} | bartowski/Mistral-Nemo-Instruct-2407-GGUF | Mistral-Nemo-Instruct-2407-Q4_K_M.gguf | mistralai-Mistral-Nemo-Instruct-2407 | - | code_interpreter | {"code": "print(\"Hello World\")"} | bartowski/Qwen2.5-7B-Instruct-GGUF | Qwen2.5-7B-Instruct-Q4_K_M.gguf | | - | code_interpreter | {"code": "print('Hello, World!')"} | bartowski/Phi-3.5-mini-instruct-GGUF | Phi-3.5-mini-instruct-Q4_K_M.gguf | | - | code_interpreter | {"code": "print('Hello, world!')"} | NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF | Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use | - | code_interpreter | {"code": "print('hello world')"} | NousResearch/Hermes-3-Llama-3.1-8B-GGUF | Hermes-3-Llama-3.1-8B.Q4_K_M.gguf | NousResearch-Hermes-3-Llama-3.1-8B-tool_use | - | code_interpreter | {"code": "print('Hello, World!'}"} | lmstudio-community/Llama-3.2-1B-Instruct-GGUF | Llama-3.2-1B-Instruct-Q4_K_M.gguf | meta-llama-Llama-3.2-3B-Instruct | - | code_interpreter | {"code": "print("} | lmstudio-community/Llama-3.2-3B-Instruct-GGUF | Llama-3.2-3B-Instruct-Q4_K_M.gguf | meta-llama-Llama-3.2-3B-Instruct | - | code_interpreter | {"code": "print("} | lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF | Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf | | - | code_interpreter | {"code": "print('Hello, World!')"} | bartowski/functionary-small-v3.2-GGUF | functionary-small-v3.2-Q8_0.gguf | meetkai-functionary-medium-v3.2 | - - - @slow - Scenario Outline: Python hello world w/o tools yields no tool call - Given a model file Phi-3.5-mini-instruct-Q4_K_M.gguf from HF repo bartowski/Phi-3.5-mini-instruct-GGUF - And no warmup - And the server is starting - And the server is healthy - And a model test - And 256 max tokens to predict - And a user prompt say hello world with python - And parallel tool calls is disabled - And an OAI compatible chat completions request with no api error - Then no tool is called - - - @slow - Scenario Outline: Python hello world w/o none tool_choice yields no tool call - Given a model file Phi-3.5-mini-instruct-Q4_K_M.gguf from HF repo bartowski/Phi-3.5-mini-instruct-GGUF - And no warmup - And the server is starting - And the server is healthy - And a model test - And 256 max tokens to predict - And a user prompt say hello world with python - And a tool choice none - And python tool - And parallel tool calls is disabled - And an OAI compatible chat completions request with no api error - Then no tool is called - - - @slow - Scenario: Parallel tool calls - Given a model file Mistral-Nemo-Instruct-2407-Q4_K_M.gguf from HF 
repo bartowski/Mistral-Nemo-Instruct-2407-GGUF - And a test chat template file named mistralai-Mistral-Nemo-Instruct-2407 - And no warmup - And the server is starting - And the server is healthy - And a model test - And 512 max tokens to predict - And a user prompt get the weather in paris and search for llama.cpp's latest commits (don't write comments in the code) - And python tool - And parallel tool calls is enabled - And an OAI compatible chat completions request with no api error - Then receiving the following tool calls: [{"arguments": {"code": "import requests\nresponse = requests.get('https://api.openweathermap.org/data/2.9/weather?q=Paris&appid=YOUR_API_KEY')\nprint(response.json())"}, "name": "ipython" , "id": "123456789"}, {"arguments": {"code": "!git log --oneline --after 2024-01-01 --before 2024-12-31 llama.cpp" }, "name": "ipython" , "id": "987654321"}] diff --git a/examples/server/tests/pytest.ini b/examples/server/tests/pytest.ini new file mode 100644 index 000000000..6510c8d98 --- /dev/null +++ b/examples/server/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + serial \ No newline at end of file diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 1e285dcda..f57a9b40f 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -4,7 +4,7 @@ set -eu if [ $# -lt 1 ] then - pytest -v -x + pytest -v -x -m "not slow" else pytest "$@" fi diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 8a439f9ef..d2dab04ca 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -163,3 +163,159 @@ def test_chat_completion_with_timings_per_token(): assert "predicted_per_second" in data["timings"] assert "predicted_n" in data["timings"] assert data["timings"]["predicted_n"] <= 10 + + +TEST_TOOL = { + "type":"function", + "function": { + "name": "test", + "description": "", + "parameters": { + "type": "object", + "properties": {} + } + } +} + +PYTHON_TOOL = { + "type": "function", + "function": { + "name": "python", + "description": "Runs code in a Python interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the Python interpreter." + } + }, + "required": ["code"] + } + } +} + +CODE_INTEPRETER_TOOL = { + "type": "code_interpreter", +} + + +@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [ + ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ), + ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". 
She was so excited to go to the park and s"} ),
+    ("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
+    ("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
+])
+def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
+    global server
+    server.use_jinja = True
+    server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "tool_choice": tool["function"]["name"],
+        "tools": [tool],
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool["function"]["name"] == tool_call["function"]["name"]
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+
+
+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meetkai-functionary-medium-v3.1", 32, [], None),
+    ("meetkai-functionary-medium-v3.1", 32, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.1", 32, [PYTHON_TOOL], 'none'),
+    ("meetkai-functionary-medium-v3.2", 32, [], None),
+    ("meetkai-functionary-medium-v3.2", 32, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.2", 32, [PYTHON_TOOL], 'none'),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    global server
+    server.use_jinja = True
+    server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "say hello world with python"},
+        ],
+        "tools": tools if tools else None,
+        "tool_choice": tool_choice,
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
+
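+
+# NOTE: the "hello world" tests below download real GGUF models from Hugging Face
+# and can take several minutes per case, so they are marked slow (declared in
+# pytest.ini) and skipped by the default runner (tests.sh runs `pytest -v -x -m "not slow"`).
+# Select them explicitly with e.g. `pytest -v -m slow unit/test_chat_completion.py`.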
+@pytest.mark.slow +@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [ + (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None), + (PYTHON_TOOL, {"code": "print(\"Hello World\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", ("mistralai-Mistral-Nemo-Instruct-2407", None)), + (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello World\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")), + (CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "lmstudio-community/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), + (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), +]) +def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): + global server + server.use_jinja = 
True
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = hf_file
+    if template_override:
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/fetch_server_test_models.py {template_hf_repo} {template_variant}` to download the template."
+    server.start(timeout_seconds=15*60)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 256,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "say hello world with python"},
+        ],
+        "tools": [tool],
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool["function"]["name"] == tool_call["function"]["name"]
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index e17a05ff6..65080402a 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -67,6 +67,8 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
+    chat_template_file: str | None = None
+    use_jinja: bool | None = None
     lora_files: List[str] | None = None
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
@@ -148,6 +150,10 @@ class ServerProcess:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.use_jinja:
+            server_args.append("--jinja")
         if self.disable_ctx_shift:
             server_args.extend(["--no-context-shift"])
         if self.api_key:
diff --git a/tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja b/tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja
new file mode 100644
index 000000000..33089ace1
--- /dev/null
+++ b/tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja
@@ -0,0 +1,109 @@
+{{- bos_token }}
+{%- if custom_tools is defined %}
+    {%- set tools = custom_tools %}
+{%- endif %}
+{%- if not tools_in_user_message is defined %}
+    {%- set tools_in_user_message = true %}
+{%- endif %}
+{%- if not date_string is defined %}
+    {%- set date_string = "26 Jul 2024" %}
+{%- endif %}
+{%- if not tools is defined %}
+    {%- set tools = none %}
+{%- endif %}
+
+{#- This block extracts the system message, so we can slot it into the right place. 
#} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index d112e395e..f21af000b 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -414,6 +414,7 @@ static void test_grammars() { test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "", "", { "<|im_end|>" }, tool_call_message, tools); test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + test_template("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "", "", { "<|eot_id|>" }, tool_call_message, tools);