From cafea60922d4d1d58648594f920dd3f474b48747 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 27 Jan 2025 22:46:33 +0000
Subject: [PATCH] Split e2e test_tool_call from test_chat_completion

---
 .../server/tests/unit/test_chat_completion.py | 234 -------------
 examples/server/tests/unit/test_tool_call.py  | 317 ++++++++++++++++++
 2 files changed, 317 insertions(+), 234 deletions(-)
 create mode 100644 examples/server/tests/unit/test_tool_call.py

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 91d629e24..fa6cbeb67 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -188,240 +188,6 @@ def test_chat_completion_with_timings_per_token():
     assert data["timings"]["predicted_n"] <= 10
 
 
-TEST_TOOL = {
-    "type":"function",
-    "function": {
-        "name": "test",
-        "description": "",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "success": {"type": "boolean", "const": True},
-            },
-            "required": ["success"]
-        }
-    }
-}
-
-PYTHON_TOOL = {
-    "type": "function",
-    "function": {
-        "name": "python",
-        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "code": {
-                    "type": "string",
-                    "description": "The code to run in the ipython interpreter."
-                }
-            },
-            "required": ["code"]
-        }
-    }
-}
-
-WEATHER_TOOL = {
-  "type":"function",
-  "function":{
-    "name":"get_current_weather",
-    "description":"Get the current weather in a given location",
-    "parameters":{
-      "type":"object",
-      "properties":{
-        "location":{
-          "type":"string",
-          "description":"The city and state, e.g. San Francisco, CA"
-        }
-      },
-      "required":["location"]
-    }
-  }
-}
-
-@pytest.mark.parametrize("template_name,tool,argument_key", [
-    ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
-    ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
-    ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
-    ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
-])
-def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
-    n_predict = 512
-    global server
-    # server = ServerPreset.stories15m_moe()
-    server.jinja = True
-    server.n_predict = n_predict
-    server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
-    server.start()
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": n_predict,
-        "messages": [
-            {"role": "system", "content": "You are a coding assistant."},
-            {"role": "user", "content": "Write an example"},
-        ],
-        "tool_choice": "required",
-        "tools": [tool],
-        "parallel_tool_calls": False,
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
-    })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
-    tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
-    tool_call = tool_calls[0]
-    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
-    assert expected_function_name == tool_call["function"]["name"]
-    actual_arguments = tool_call["function"]["arguments"]
-    assert isinstance(actual_arguments, str)
-    if argument_key is not None:
-        actual_arguments = json.loads(actual_arguments)
-        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
-
-
-@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
-    ("meetkai-functionary-medium-v3.1", 128, [], None),
-    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
-    ("meetkai-functionary-medium-v3.2", 128, [], None),
-    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
-    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
-])
-def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
-    global server
-    server.jinja = True
-    server.n_predict = n_predict
-    server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
-    server.start()
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": n_predict,
-        "messages": [
-            {"role": "system", "content": "You are a coding assistant."},
-            {"role": "user", "content": "say hello world with python"},
-        ],
-        "tools": tools if tools else None,
-        "tool_choice": tool_choice,
-        "temperature": 0.0,
-        "top_k": 1,
-        "top_p": 1.0,
-    })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
-    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
-    ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    ("NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    ("NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
-    ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-    ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-])
-def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
-    global server
-    server.n_slots = 1
-    server.jinja = True
-    server.n_ctx = 8192
-    server.n_predict = 128
-    server.model_hf_repo = hf_repo
-    server.model_hf_file = hf_file
-    if template_override:
-        (template_hf_repo, template_variant) = template_override
-        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
-    server.start(timeout_seconds=15*60)
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": 256,
-        "messages": [
-            {"role": "user", "content": "What is the weather in Istanbul?"},
-        ],
-        "tools": [WEATHER_TOOL],
-    })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
-    tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
-    tool_call = tool_calls[0]
-    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
-    actual_arguments = json.loads(tool_call["function"]["arguments"])
-    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
-    location = actual_arguments["location"]
-    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
-    assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
-    (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
-    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
-])
-def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
-    global server
-    server.n_slots = 1
-    server.jinja = True
-    server.n_ctx = 8192
-    server.n_predict = 128
-    server.model_hf_repo = hf_repo
-    server.model_hf_file = hf_file
-    if template_override:
-        (template_hf_repo, template_variant) = template_override
-        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
-        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
-    server.start(timeout_seconds=15*60)
-    res = server.make_request("POST", "/chat/completions", data={
-        "max_tokens": 256,
-        "messages": [
-            {"role": "system", "content": "You are a coding assistant."},
-            {"role": "user", "content": "say hello world with python"},
-            # {"role": "user", "content": "Print a hello world message with python"},
-        ],
-        "tools": [PYTHON_TOOL],
-    })
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
-    tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
-    tool_call = tool_calls[0]
-    assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
-    actual_arguments = tool_call["function"]["arguments"]
-    if expected_arguments is not None:
-        assert actual_arguments == expected_arguments
-    else:
-        actual_arguments = json.loads(actual_arguments)
-        assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
-        code = actual_arguments["code"]
-        assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
-        assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
-
-
 def test_logprobs():
     global server
     server.start()
diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py
new file mode 100644
index 000000000..0cd1d4923
--- /dev/null
+++ b/examples/server/tests/unit/test_tool_call.py
@@ -0,0 +1,317 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server: ServerProcess
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+TEST_TOOL = {
+    "type":"function",
+    "function": {
+        "name": "test",
+        "description": "",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "success": {"type": "boolean", "const": True},
+            },
+            "required": ["success"]
+        }
+    }
+}
+
+PYTHON_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "python",
+        "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "code": {
+                    "type": "string",
+                    "description": "The code to run in the ipython interpreter."
+                }
+            },
+            "required": ["code"]
+        }
+    }
+}
+
+WEATHER_TOOL = {
+  "type":"function",
+  "function":{
+    "name":"get_current_weather",
+    "description":"Get the current weather in a given location",
+    "parameters":{
+      "type":"object",
+      "properties":{
+        "location":{
+          "type":"string",
'San Francisco, CA', or 'Paris, France'" + } + }, + "required":["location"] + } + } +} + + +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), + # TODO: fix these + # ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + # ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): + n_predict = 512 + global server + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja' + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +@pytest.mark.slow +@pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [ + (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), + (TEST_TOOL, "success", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", 
("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (TEST_TOOL, "success", "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None), + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), + (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None), + # TODO: fix these + # (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), + # (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), +]) +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None): + n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = hf_file + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "tool_choice": "required",
+        "tools": [tool],
+        "parallel_tool_calls": False,
+        "temperature": 0.0,
+        "top_k": 1,
+        "top_p": 1.0,
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
+    assert expected_function_name == tool_call["function"]["name"]
+    actual_arguments = tool_call["function"]["arguments"]
+    assert isinstance(actual_arguments, str)
+    if argument_key is not None:
+        actual_arguments = json.loads(actual_arguments)
+        assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
+
+
+@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
+    ("meetkai-functionary-medium-v3.1", 128, [], None),
+    ("meetkai-functionary-medium-v3.1", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.1", 128, [PYTHON_TOOL], 'none'),
+    ("meetkai-functionary-medium-v3.2", 128, [], None),
+    ("meetkai-functionary-medium-v3.2", 128, [TEST_TOOL], None),
+    ("meetkai-functionary-medium-v3.2", 128, [PYTHON_TOOL], 'none'),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [], None),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [TEST_TOOL], None),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, [PYTHON_TOOL], 'none'),
+])
+def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+    global server
+    server.jinja = True
+    server.n_predict = n_predict
+    server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "say hello world with python"},
+        ],
+        "tools": tools if tools else None,
+        "tool_choice": tool_choice,
+        "temperature": 0.0,
+        "top_k": 1,
+        "top_p": 1.0,
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
+    ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    ("NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
+    ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+])
+def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+    global server
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = 512
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = hf_file
+    if template_override:
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    server.start(timeout_seconds=15*60)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 256,
+        "messages": [
+            # {"role": "system", "content": "Use tools as appropriate."},
+            {"role": "user", "content": "What is the weather in Istanbul?"},
+        ],
+        "tools": [WEATHER_TOOL],
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
+    location = actual_arguments["location"]
+    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
+    assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
+    (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
+    (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+])
+def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+    global server
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = 128
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = hf_file
+    if template_override:
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    server.start(timeout_seconds=15*60)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 256,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "say hello world with python"},
+            # {"role": "user", "content": "Print a hello world message with python"},
+        ],
+        "tools": [PYTHON_TOOL],
+    })
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
+    actual_arguments = tool_call["function"]["arguments"]
+    if expected_arguments is not None:
+        assert actual_arguments == expected_arguments
+    else:
+        actual_arguments = json.loads(actual_arguments)
+        assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
+        code = actual_arguments["code"]
+        assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
+        assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'