tool-call: fix lazy grammar & mixed content + tool calls parsing

parent 2efa0c27bf
commit 57f40e366b

3 changed files with 25 additions and 33 deletions
@@ -89,7 +89,8 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
         if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
             fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
             result.content += std::string(it, rit->suffix().first);
-            break;
+            it = rit->suffix().first;
+            continue;
         }

         result.content += std::string(it, rit->prefix().second);
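Why the `break` had to go: `parse_json_tool_calls` scans the model output for tool-call matches while accumulating the surrounding plain text as message content. Breaking out on the first unknown tool name silently dropped both the remaining mixed content and any valid tool calls after it; advancing `it` past the match and continuing keeps scanning. Below is a self-contained sketch of that loop shape after the fix; the regex, tool format, and names are invented for illustration and are not the actual llama.cpp pattern.

```cpp
#include <algorithm>
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> tool_names = {"get_weather"};
    const std::string output =
        "Let me check. {\"name\": \"unknown_tool\", \"arguments\": {}} "
        "and {\"name\": \"get_weather\", \"arguments\": {\"location\": \"Istanbul\"}} done.";

    // Hypothetical pattern: a tool call rendered inline as a small JSON object.
    const std::regex call_re("\\{\"name\": \"(\\w+)\", \"arguments\": (\\{[^}]*\\})\\}");

    std::string content;             // plain (non-tool-call) text
    std::vector<std::string> calls;  // arguments of recognized tool calls
    auto it = output.begin();
    for (std::sregex_iterator m(output.begin(), output.end(), call_re), end; m != end; ++m) {
        if (std::find(tool_names.begin(), tool_names.end(), (*m)[1].str()) == tool_names.end()) {
            // Unknown tool: keep its text as content and keep scanning (continue, not break).
            content += std::string(it, output.begin() + m->position() + m->length());
            it = output.begin() + m->position() + m->length();
            continue;
        }
        // Known tool: text before the match is content, the match is a tool call.
        content += std::string(it, output.begin() + m->position());
        calls.push_back((*m)[2].str());
        it = output.begin() + m->position() + m->length();
    }
    content += std::string(it, output.end());  // trailing content after the last match

    printf("content: %s\n", content.c_str());
    for (const auto & c : calls) printf("call args: %s\n", c.c_str());
}
```

With the old `break`, this input would have lost both the trailing " done." content and the valid `get_weather` call that follows the unknown one.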
@@ -325,20 +325,18 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:


 @pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
-    (None, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
-    (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
-    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
-    # TODO: fix these models
-    # (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    # (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    # (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+@pytest.mark.parametrize("hf_repo,hf_file,template_override", [
+    ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    ("NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
+    ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
 ])
-def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
     server.n_slots = 1
     server.jinja = True
@@ -357,9 +355,6 @@ def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file
             {"role": "user", "content": "What is the weather in Istanbul?"},
         ],
         "tools": [WEATHER_TOOL],
-        # "temperature": 0.5,
-        # "top_k": 10,
-        # "top_p": 0.9,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -367,19 +362,17 @@ def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
-    actual_arguments = tool_call["function"]["arguments"]
-    if expected_arguments is not None:
-        assert actual_arguments == expected_arguments
-    else:
-        actual_arguments = json.loads(actual_arguments)
-        assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
-        location = actual_arguments["location"]
-        assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
-        assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
+    location = actual_arguments["location"]
+    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
+    assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'


 @pytest.mark.slow
 @pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
+    (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
     (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
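The weather test now parses the arguments as JSON and validates the `location` field instead of comparing the raw `arguments` string against an expected literal: model-generated JSON can legitimately vary in key order, whitespace, and spelling ("Istanbul", "Istanbul, TR", "Istanbul, Türkiye"), so byte-for-byte comparison was brittle. A minimal C++ analogue of that parse-and-validate check, with a hypothetical extracted value, kept in C++ for consistency with the other sketches here:

```cpp
#include <cassert>
#include <regex>
#include <string>

int main() {
    // Pretend this was extracted from the tool call's JSON arguments.
    const std::string location = "Istanbul, Türkiye";

    // Accept any reasonable spelling rather than one exact string.
    const std::regex ok("^Istanbul(, (TR|Turkey|Türkiye))?$");
    assert(std::regex_match(location, ok));
}
```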
@@ -387,10 +380,7 @@ def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file
     (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
     (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
-    # TODO: fix these models
-    # (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    # (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
+    (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
 ])
 def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
@@ -413,9 +403,6 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
             # {"role": "user", "content": "Print a hello world message with python"},
         ],
         "tools": [PYTHON_TOOL],
-        "temperature": 0.5,
-        "top_k": 10,
-        "top_p": 0.9,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
@@ -1116,6 +1116,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
 void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
     GGML_ASSERT(grammar.vocab != nullptr);

+    if (grammar.awaiting_trigger) {
+        return;
+    }
+
     bool allow_eog = false;
     for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
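The early return is what makes the grammar "lazy": while `awaiting_trigger` is set, `llama_grammar_apply_impl` must not constrain sampling at all, so plain content can stream freely; the grammar only starts filtering tokens once a trigger (for example a tool-call opener) has been produced. A minimal sketch of that gating idea, assuming a `trigger` string and an `accept` hook for generated text; the real implementation works on tokens and grammar stacks rather than raw strings:

```cpp
#include <cstdio>
#include <string>

struct lazy_grammar {
    bool awaiting_trigger = true;  // becomes false once the trigger is seen
    std::string trigger;
    std::string seen;              // text accepted so far

    void accept(const std::string & piece) {
        seen += piece;
        if (awaiting_trigger && seen.find(trigger) != std::string::npos) {
            awaiting_trigger = false;  // from here on, the grammar applies
        }
    }
    // Mirrors the early return above: no constraint while awaiting the trigger.
    bool constrains() const { return !awaiting_trigger; }
};

int main() {
    lazy_grammar g;
    g.trigger = "<tool_call>";
    g.accept("The weather is ");                  // free-form content
    printf("constrained: %d\n", g.constrains());  // 0
    g.accept("<tool_call>");                      // trigger seen
    printf("constrained: %d\n", g.constrains());  // 1
}
```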