follow enum naming style for tool call styles

This commit is contained in:
Olivier Chafik 2025-01-22 02:13:02 +00:00
parent 5268ec8947
commit 9e8b43f993
5 changed files with 86 additions and 86 deletions

View file

@ -49,25 +49,25 @@ static json normalize_tools(const json & tools) {
std::string common_tool_call_style_name(common_tool_call_style style) { std::string common_tool_call_style_name(common_tool_call_style style) {
switch (style) { switch (style) {
case common_tool_call_style::None: case COMMON_TOOL_CALL_STYLE_NONE:
return "None"; return "None";
case common_tool_call_style::Generic: case COMMON_TOOL_CALL_STYLE_GENERIC:
return "Generic"; return "Generic";
case common_tool_call_style::Llama31: case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
return "Llama-3.1"; return "Llama-3.1";
case common_tool_call_style::Llama32: case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
return "Llama-3.2"; return "Llama-3.2";
case common_tool_call_style::FunctionaryV3Llama3: case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
return "FunctionaryV3Llama3"; return "FunctionaryV3Llama3";
case common_tool_call_style::FunctionaryV3Llama31: case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
return "FunctionaryV3Llama3.1"; return "FunctionaryV3Llama3.1";
case common_tool_call_style::Hermes2Pro: case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
return "Hermes2Pro"; return "Hermes2Pro";
case common_tool_call_style::CommandRPlus: case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS:
return "CommandRPlus"; return "CommandRPlus";
case common_tool_call_style::MistralNemo: case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
return "MistralNemo"; return "MistralNemo";
case common_tool_call_style::FirefunctionV2: case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
return "FirefunctionV2"; return "FirefunctionV2";
default: default:
return "Unknown"; return "Unknown";
@ -78,26 +78,26 @@ common_tool_call_style common_tool_call_style_detect(const common_chat_template
const auto & src = chat_template.source(); const auto & src = chat_template.source();
if (src.find("<tool_call>") != std::string::npos) { if (src.find("<tool_call>") != std::string::npos) {
return Hermes2Pro; return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
} else if (src.find(">>>all") != std::string::npos) { } else if (src.find(">>>all") != std::string::npos) {
return FunctionaryV3Llama3; return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
} else if (src.find("<|start_header_id|>") != std::string::npos } else if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) { && src.find("<function=") != std::string::npos) {
return FunctionaryV3Llama31; return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { } else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (src.find("<|python_tag|>") != std::string::npos) { if (src.find("<|python_tag|>") != std::string::npos) {
return Llama31; return COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
} else { } else {
return Llama32; return COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
} }
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
return CommandRPlus; return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
} else if (src.find("[TOOL_CALLS]") != std::string::npos) { } else if (src.find("[TOOL_CALLS]") != std::string::npos) {
return MistralNemo; return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
} else if (src.find(" functools[") != std::string::npos) { } else if (src.find(" functools[") != std::string::npos) {
return FirefunctionV2; return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
} else { } else {
return Generic; return COMMON_TOOL_CALL_STYLE_GENERIC;
} }
} }
@ -356,23 +356,23 @@ static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& inp
common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str()); fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
switch (style) { switch (style) {
case common_tool_call_style::None: case COMMON_TOOL_CALL_STYLE_NONE:
return {input, {}}; return {input, {}};
case common_tool_call_style::Generic: case COMMON_TOOL_CALL_STYLE_GENERIC:
return parse_generic_tool_calls(input); return parse_generic_tool_calls(input);
case common_tool_call_style::Llama31: case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true); return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
case common_tool_call_style::Llama32: case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false); return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false);
case common_tool_call_style::FunctionaryV3Llama3: case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
return parse_functionary_v3_tool_calls(tools, input); return parse_functionary_v3_tool_calls(tools, input);
case common_tool_call_style::FunctionaryV3Llama31: case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
return parse_functionary_v3_llama_3_1_tool_calls(tools, input); return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
case common_tool_call_style::Hermes2Pro: case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
return parse_hermes_tool_calls(input); return parse_hermes_tool_calls(input);
case common_tool_call_style::MistralNemo: case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
return parse_mistral_nemo_tool_calls(input); return parse_mistral_nemo_tool_calls(input);
case common_tool_call_style::FirefunctionV2: case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
return parse_firefunction_v2_tool_calls(input); return parse_firefunction_v2_tool_calls(input);
default: default:
throw std::runtime_error("Unsupported tool call style"); throw std::runtime_error("Unsupported tool call style");
@ -410,10 +410,10 @@ common_tool_call_handler common_tool_call_handler_init(
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>(); auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
switch (style) { switch (style) {
case common_tool_call_style::None: case COMMON_TOOL_CALL_STYLE_NONE:
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
break; break;
case common_tool_call_style::Generic: { case COMMON_TOOL_CALL_STYLE_GENERIC: {
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
auto tool_call_schemas = json::array(); auto tool_call_schemas = json::array();
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
@ -493,7 +493,7 @@ common_tool_call_handler common_tool_call_handler_init(
handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break; break;
} }
case common_tool_call_style::MistralNemo: { case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
@ -534,7 +534,7 @@ common_tool_call_handler common_tool_call_handler_init(
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break; break;
} }
case common_tool_call_style::FirefunctionV2: { case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
@ -568,8 +568,8 @@ common_tool_call_handler common_tool_call_handler_init(
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break; break;
} }
case common_tool_call_style::Llama31: case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
case common_tool_call_style::Llama32: { case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto builtin_tools = json {"wolfram_alpha", "brave_search"};
for (const auto & tool : tools) { for (const auto & tool : tools) {
if (!tool.contains("type")) { if (!tool.contains("type")) {
@ -582,13 +582,13 @@ common_tool_call_handler common_tool_call_handler_init(
} }
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
auto uses_python_tag = style == common_tool_call_style::Llama31; auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON. // as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B). // TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = style == common_tool_call_style::Llama32; auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
@ -639,7 +639,7 @@ common_tool_call_handler common_tool_call_handler_init(
}); });
break; break;
} }
case common_tool_call_style::FunctionaryV3Llama3: { case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
@ -670,7 +670,7 @@ common_tool_call_handler common_tool_call_handler_init(
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
} }
case common_tool_call_style::FunctionaryV3Llama31: { case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python // TODO: handle tool {type: code_interpreter} as python
@ -700,7 +700,7 @@ common_tool_call_handler common_tool_call_handler_init(
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
} }
case common_tool_call_style::Hermes2Pro: { case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
// NousResearchHermesPro_2 // NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);

View file

@ -8,17 +8,17 @@
#include "json.hpp" #include "json.hpp"
enum common_tool_call_style { enum common_tool_call_style {
UnknownToolCallStyle, COMMON_TOOL_CALL_STYLE_UNKNOWN,
None, COMMON_TOOL_CALL_STYLE_NONE,
Generic, COMMON_TOOL_CALL_STYLE_GENERIC,
Llama31, COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
Llama32, COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
FunctionaryV3Llama3, COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
FunctionaryV3Llama31, COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
Hermes2Pro, COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
CommandRPlus, COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
MistralNemo, COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
FirefunctionV2, COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
}; };
struct common_tool_call { struct common_tool_call {

View file

@ -118,7 +118,7 @@ struct slot_params {
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
json oaicompat_tools; json oaicompat_tools;
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None; common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
json to_json() const { json to_json() const {
std::vector<std::string> samplers; std::vector<std::string> samplers;
@ -589,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
json oaicompat_tools; json oaicompat_tools;
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None; common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
@ -690,7 +690,7 @@ struct server_task_result_cmpl_final : server_task_result {
common_tool_calls parsed_tool_calls; common_tool_calls parsed_tool_calls;
json tool_calls; json tool_calls;
json message_content; json message_content;
if (oaicompat_tool_call_style != common_tool_call_style::None && !oaicompat_tools.is_null()) { if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) { if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls"; finish_reason = "tool_calls";
@ -3772,7 +3772,7 @@ int main(int argc, char ** argv) {
std::function<bool()> is_connection_closed, std::function<bool()> is_connection_closed,
httplib::Response & res, httplib::Response & res,
oaicompat_type oaicompat, oaicompat_type oaicompat,
common_tool_call_style tool_call_style = common_tool_call_style::None) { common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) { if (ctx_server.params_base.embedding) {

View file

@ -595,7 +595,7 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Cannot use tools with stream"); throw std::runtime_error("Cannot use tools with stream");
} }
if (use_jinja) { if (use_jinja) {
if (tool_call_style == common_tool_call_style::UnknownToolCallStyle) { if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template."); throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
} }
} else { } else {

View file

@ -146,21 +146,21 @@ static void test_parsing() {
}} }}
}; };
test_parse_tool_call(common_tool_call_style::Generic, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
"{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}", "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
"", "",
json::array({fooBarCall})); json::array({fooBarCall}));
test_parse_tool_call(common_tool_call_style::Generic, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
"{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}", "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
"", "",
json::array({fooBarCall})); json::array({fooBarCall}));
test_parse_tool_call(common_tool_call_style::Hermes2Pro, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>", "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
"", "",
json::array({fooBarCall})); json::array({fooBarCall}));
test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama3, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
">>>python\n{\"code\": \"print('Hello, world!')\"}", ">>>python\n{\"code\": \"print('Hello, world!')\"}",
"", "",
json {{ json {{
@ -172,7 +172,7 @@ static void test_parsing() {
})} })}
}} }}
}}); }});
test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama3, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
">>>special_function\n{\"arg1\": 1}\n ", ">>>special_function\n{\"arg1\": 1}\n ",
"", "",
json {{ json {{
@ -185,7 +185,7 @@ static void test_parsing() {
}} }}
}}); }});
test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!", "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
"Hello, world!", "Hello, world!",
json { json {
@ -208,7 +208,7 @@ static void test_parsing() {
}} }}
}, },
}); });
test_parse_tool_call(common_tool_call_style::FunctionaryV3Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
"<function=test>{ } </function> ", "<function=test>{ } </function> ",
" ", " ",
json {{ json {{
@ -219,7 +219,7 @@ static void test_parsing() {
}} }}
}}); }});
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"<|python_tag|>this could be anything", "<|python_tag|>this could be anything",
"", "",
json {{ json {{
@ -231,7 +231,7 @@ static void test_parsing() {
})} })}
}} }}
}}); }});
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"I'm thinking<|python_tag|>", "I'm thinking<|python_tag|>",
"I'm thinking", "I'm thinking",
json {{ json {{
@ -253,7 +253,7 @@ static void test_parsing() {
auto no_function_call = json::array(); auto no_function_call = json::array();
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}", "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
"", "",
json::array({{ json::array({{
@ -263,48 +263,48 @@ static void test_parsing() {
{"name", "python"}, {"name", "python"},
}} }}
}})); }}));
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json::array({special_function_call})); json::array({special_function_call}));
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json::array({special_function_call})); json::array({special_function_call}));
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json::array({special_function_call})); json::array({special_function_call}));
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json::array({special_function_call})); json::array({special_function_call}));
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json::array({special_function_call})); json::array({special_function_call}));
// No match: function unknown // No match: function unknown
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call); no_function_call);
// No match: bad indentation // No match: bad indentation
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call); no_function_call);
test_parse_tool_call(common_tool_call_style::Llama31, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
no_function_call); no_function_call);
test_parse_tool_call(common_tool_call_style::MistralNemo, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
"Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]", "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
"Bleh", "Bleh",
json::array({special_function_call_with_id})); json::array({special_function_call_with_id}));
test_parse_tool_call(common_tool_call_style::FirefunctionV2, tools, test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
"Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]", "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
"Bleh", "Bleh",
json::array({special_function_call})); json::array({special_function_call}));
@ -318,17 +318,17 @@ static void test_tool_call_style(const std::string & template_file, common_tool_
} }
static void test_tool_call_style_detection() { static void test_tool_call_style_detection() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31); test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3); test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3);
test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", FirefunctionV2); test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31); test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32); test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2);
test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", Hermes2Pro); test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", Hermes2Pro); test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", Hermes2Pro); test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus); test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS);
test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", MistralNemo); test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO);
test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic); test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC);
} }
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) { static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {