Greedy sampling in tool call tests

This commit is contained in:
Olivier Chafik 2025-01-22 11:41:43 +00:00
parent cce1166b37
commit c6a22edc57
2 changed files with 22 additions and 18 deletions

View file

@ -687,11 +687,10 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}
common_chat_msg parsed_tool_calls;
json tool_calls;
json message_content;
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
message_content = parsed_tool_calls.content;
@ -716,7 +715,7 @@ struct server_task_result_cmpl_final : server_task_result {
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", {
{"message", json {
{"content", message_content},
{"tool_calls", tool_calls},
{"role", "assistant"},

View file

@ -225,19 +225,19 @@ CODE_INTEPRETER_TOOL = {
@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and c"} ),
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes, you can."} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spector."} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a spectork."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a speciachy!"} ),
])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
global server
@ -254,6 +254,9 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
@ -290,6 +293,9 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
],
"tools": tools if tools else None,
"tool_choice": tool_choice,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
@ -339,7 +345,6 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
{"role": "user", "content": "say hello world with python"},
],
"tools": [tool],
# Greedy sampling
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,