Greedy sampling in tool call tests

This commit is contained in:
Olivier Chafik 2025-01-22 11:41:43 +00:00
parent cce1166b37
commit c6a22edc57
2 changed files with 22 additions and 18 deletions

View file

@ -687,11 +687,10 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop"; finish_reason = "stop";
} }
common_chat_msg parsed_tool_calls;
json tool_calls; json tool_calls;
json message_content; json message_content;
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) { if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) { if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls"; finish_reason = "tool_calls";
message_content = parsed_tool_calls.content; message_content = parsed_tool_calls.content;
@ -716,7 +715,7 @@ struct server_task_result_cmpl_final : server_task_result {
json choice { json choice {
{"finish_reason", finish_reason}, {"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"message", { {"message", json {
{"content", message_content}, {"content", message_content},
{"tool_calls", tool_calls}, {"tool_calls", tool_calls},
{"role", "assistant"}, {"role", "assistant"},

View file

@ -224,20 +224,20 @@ CODE_INTEPRETER_TOOL = {
@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [ @pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ), ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ), ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and c"} ),
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ), ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ), ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ), ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ), ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ), ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ), ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spectork."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ), ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ), ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a speciachy!"} ),
]) ])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict): def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
global server global server
@ -254,6 +254,9 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
"tool_choice": "required", "tool_choice": "required",
"tools": [tool], "tools": [tool],
"parallel_tool_calls": False, "parallel_tool_calls": False,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}) })
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0] choice = res.body["choices"][0]
@ -290,6 +293,9 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
], ],
"tools": tools if tools else None, "tools": tools if tools else None,
"tool_choice": tool_choice, "tool_choice": tool_choice,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}) })
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0] choice = res.body["choices"][0]
@ -339,7 +345,6 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
{"role": "user", "content": "say hello world with python"}, {"role": "user", "content": "say hello world with python"},
], ],
"tools": [tool], "tools": [tool],
# Greedy sampling
"temperature": 0.0, "temperature": 0.0,
"top_k": 1, "top_k": 1,
"top_p": 1.0, "top_p": 1.0,