From c6a22edc57127a1a61378d09ba60e5403971dd0b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 11:41:43 +0000 Subject: [PATCH] Greedy sampling in tool call tests --- examples/server/server.cpp | 5 ++- .../server/tests/unit/test_chat_completion.py | 35 +++++++++++-------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d64f025fd..e32577480 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -687,11 +687,10 @@ struct server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - common_chat_msg parsed_tool_calls; json tool_calls; json message_content; if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) { - parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); + auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); if (!parsed_tool_calls.tool_calls.empty()) { finish_reason = "tool_calls"; message_content = parsed_tool_calls.content; @@ -716,7 +715,7 @@ struct server_task_result_cmpl_final : server_task_result { json choice { {"finish_reason", finish_reason}, {"index", 0}, - {"message", { + {"message", json { {"content", message_content}, {"tool_calls", tool_calls}, {"role", "assistant"}, diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index aeba6374d..2c9f5816c 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -224,20 +224,20 @@ CODE_INTEPRETER_TOOL = { @pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [ - ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ), - ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ), - ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ), - ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ), - ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ), - ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ), - ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ), - ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ), - ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ), - ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ), + ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ), + ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and c"} ), + ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ), + ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "It's a spector."} ), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ), + ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spector."} ), + ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ), + ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spectork."} ), + ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ), + ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a speciachy!"} ), ]) def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict): global server @@ -254,6 +254,9 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool: "tool_choice": "required", "tools": [tool], "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, }) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -290,6 +293,9 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: ], "tools": tools if tools else None, "tool_choice": tool_choice, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, }) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -339,7 +345,6 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st {"role": "user", "content": "say hello world with python"}, ], "tools": [tool], - # Greedy sampling "temperature": 0.0, "top_k": 1, "top_p": 1.0,