reshuffle chat handlers
This commit is contained in:
parent
43385b2ff2
commit
5ec4c5e4d3
3 changed files with 318 additions and 257 deletions
|
@ -76,7 +76,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
||||||
if (type == "function") {
|
if (type == "function") {
|
||||||
tool_names.push_back(tool["function"]["name"]);
|
tool_names.push_back(tool["function"]["name"]);
|
||||||
} else if (type == "code_interpreter") {
|
} else if (type == "code_interpreter") {
|
||||||
tool_names.push_back("ipython");
|
tool_names.push_back("python");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -171,6 +171,10 @@ public:
|
||||||
/* .tool_calls = */ {},
|
/* .tool_calls = */ {},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<common_chat_parser> clone() const override {
|
||||||
|
return std::make_unique<text_chat_parser>();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class monolithic_chat_parser : public common_chat_parser {
|
class monolithic_chat_parser : public common_chat_parser {
|
||||||
|
@ -192,13 +196,48 @@ public:
|
||||||
input_buffer_.clear();
|
input_buffer_.clear();
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<common_chat_parser> clone() const override {
|
||||||
|
return std::make_unique<monolithic_chat_parser>(parse_final_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
const auto python_tool = json::parse(R"({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "python",
|
||||||
|
"description": "an ipython interpreter",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Python code to execute."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})");
|
||||||
|
|
||||||
|
static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
fn(python_tool);
|
||||||
|
} else {
|
||||||
|
fn(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
auto tool_call_schemas = json::array();
|
auto tool_call_schemas = json::array();
|
||||||
for (const auto & tool : params.tools) {
|
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
auto tool_schema = json {
|
auto tool_schema = json {
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -222,7 +261,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
|
||||||
tool_schema["required"].push_back("id");
|
tool_schema["required"].push_back("id");
|
||||||
}
|
}
|
||||||
tool_call_schemas.emplace_back(tool_schema);
|
tool_call_schemas.emplace_back(tool_schema);
|
||||||
}
|
});
|
||||||
const auto tool_call =
|
const auto tool_call =
|
||||||
params.parallel_tool_calls
|
params.parallel_tool_calls
|
||||||
? json {
|
? json {
|
||||||
|
@ -276,7 +315,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
|
||||||
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
"Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
|
||||||
|
|
||||||
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
|
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
|
||||||
json data = json::parse(input);
|
json data = json::parse(input);
|
||||||
common_chat_msg result;
|
common_chat_msg result;
|
||||||
result.role = "assistant";
|
result.role = "assistant";
|
||||||
|
@ -303,13 +342,11 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
for (const auto & tool : params.tools) {
|
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -329,7 +366,7 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
|
||||||
}},
|
}},
|
||||||
{"required", json::array({"name", "arguments", "id"})},
|
{"required", json::array({"name", "arguments", "id"})},
|
||||||
});
|
});
|
||||||
}
|
});
|
||||||
auto schema = json {
|
auto schema = json {
|
||||||
{"type", "array"},
|
{"type", "array"},
|
||||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
|
@ -344,24 +381,14 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
|
||||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||||
}
|
}
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
||||||
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
||||||
});
|
});
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
|
||||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
for (const auto & tool : params.tools) {
|
|
||||||
if (!tool.contains("type")) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (tool["type"] == "code_interpreter") {
|
|
||||||
builtin_tools.push_back("code_interpreter");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
|
@ -375,6 +402,7 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tool["type"] == "code_interpreter") {
|
if (tool["type"] == "code_interpreter") {
|
||||||
|
builtin_tools.push_back("code_interpreter");
|
||||||
has_python = true;
|
has_python = true;
|
||||||
} else if (tool["type"] == "function" && tool.contains("function")) {
|
} else if (tool["type"] == "function" && tool.contains("function")) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
|
@ -422,8 +450,10 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
|
||||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
data.additional_stops.push_back("<|eom_id|>");
|
data.additional_stops.push_back("<|eom_id|>");
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
|
{"builtin_tools", builtin_tools},
|
||||||
|
});
|
||||||
|
data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
|
||||||
if (uses_python_tag) {
|
if (uses_python_tag) {
|
||||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
|
@ -448,11 +478,11 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
auto schemas = json::array();
|
auto schemas = json::array();
|
||||||
for (const auto & tool : params.tools) {
|
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
schemas.push_back({
|
schemas.push_back({
|
||||||
{"type", "object"},
|
{"type", "object"},
|
||||||
|
@ -465,7 +495,7 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
|
||||||
}},
|
}},
|
||||||
{"required", json::array({"name", "arguments", "id"})},
|
{"required", json::array({"name", "arguments", "id"})},
|
||||||
});
|
});
|
||||||
}
|
});
|
||||||
auto schema = json {
|
auto schema = json {
|
||||||
{"type", "array"},
|
{"type", "array"},
|
||||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||||
|
@ -480,13 +510,13 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
|
||||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||||
}
|
}
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
|
||||||
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
||||||
});
|
});
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
@ -530,7 +560,7 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||||
static std::regex close_regex(R"($|(?=>>>))");
|
static std::regex close_regex(R"($|(?=>>>))");
|
||||||
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
|
@ -538,11 +568,12 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
// TODO: handle tool {type: code_interpreter} as python
|
// TODO: handle tool {type: code_interpreter} as python
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
@ -578,7 +609,7 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
|
||||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||||
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
|
@ -602,12 +633,12 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
for (const auto & tool : params.tools) {
|
foreach_normalized_tool(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
auto parameters = function["parameters"];
|
auto parameters = function["parameters"];
|
||||||
|
@ -620,8 +651,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
|
||||||
}},
|
}},
|
||||||
{"required", json::array({"name", "arguments"})},
|
{"required", json::array({"name", "arguments"})},
|
||||||
}));
|
}));
|
||||||
}
|
});
|
||||||
|
|
||||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||||
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
if (params.tool_choice != "required") {
|
if (params.tool_choice != "required") {
|
||||||
|
@ -630,7 +660,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
|
||||||
}, grammar_options);
|
}, grammar_options);
|
||||||
|
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
|
data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
|
||||||
try {
|
try {
|
||||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||||
|
@ -677,24 +707,40 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
if (params.tools.is_null()) {
|
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
|
||||||
data.handler = std::make_unique<text_chat_parser>();
|
data.parser = std::make_unique<text_chat_parser>();
|
||||||
|
if (!params.json_schema.is_null()) {
|
||||||
|
if (!params.grammar.empty()) {
|
||||||
|
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||||
|
}
|
||||||
|
data.grammar = json_schema_to_grammar(params.json_schema);
|
||||||
|
} else {
|
||||||
|
data.grammar = params.grammar.empty();
|
||||||
|
}
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
const auto & src = tmpl.source();
|
|
||||||
|
|
||||||
|
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
|
if (params.tools.is_null()) {
|
||||||
|
return common_chat_init_without_tools(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.grammar.empty()) {
|
||||||
|
throw std::runtime_error("Cannot specify grammar with tools");
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto & src = tmpl.source();
|
||||||
if (src.find("<tool_call>") != std::string::npos) {
|
if (src.find("<tool_call>") != std::string::npos) {
|
||||||
return build_hermes_2_pro_tool_call_handler(tmpl, params);
|
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
if (src.find(">>>all") != std::string::npos) {
|
if (src.find(">>>all") != std::string::npos) {
|
||||||
return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
|
return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
if (src.find("<|start_header_id|>") != std::string::npos
|
if (src.find("<|start_header_id|>") != std::string::npos
|
||||||
&& src.find("<function=") != std::string::npos) {
|
&& src.find("<function=") != std::string::npos) {
|
||||||
return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
|
return common_chat_init_functionary_v3_llama_3_1_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||||
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
|
auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
|
||||||
|
@ -705,16 +751,16 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
|
||||||
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
|
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
|
||||||
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
|
auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
|
||||||
|
|
||||||
return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
|
return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
|
||||||
}
|
}
|
||||||
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
// if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||||
// TODO: Command-R-Plus
|
// TODO: Command-R-Plus
|
||||||
// }
|
// }
|
||||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||||
return build_mistral_nemo_tool_call_handler(tmpl, params);
|
return common_chat_init_mistral_nemo_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
if (src.find(" functools[") != std::string::npos) {
|
if (src.find(" functools[") != std::string::npos) {
|
||||||
return build_firefunction_v2_tool_call_handler(tmpl, params);
|
return common_chat_init_firefunction_v2_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
return build_generic_tool_call_handler(tmpl, params);
|
return common_chat_init_generic_tool_call(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ struct common_chat_params {
|
||||||
json json_schema;
|
json json_schema;
|
||||||
bool parallel_tool_calls;
|
bool parallel_tool_calls;
|
||||||
bool stream;
|
bool stream;
|
||||||
|
std::string grammar;
|
||||||
};
|
};
|
||||||
|
|
||||||
class common_chat_parser {
|
class common_chat_parser {
|
||||||
|
@ -30,14 +31,15 @@ public:
|
||||||
|
|
||||||
virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
|
virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
|
||||||
virtual common_chat_msg parse_final(const std::string & input) = 0;
|
virtual common_chat_msg parse_final(const std::string & input) = 0;
|
||||||
|
virtual std::unique_ptr<common_chat_parser> clone() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_data {
|
struct common_chat_data {
|
||||||
std::string prompt;
|
json prompt;
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
std::vector<common_grammar_trigger> grammar_triggers;
|
std::vector<common_grammar_trigger> grammar_triggers;
|
||||||
std::vector<std::string> additional_stops;
|
std::vector<std::string> additional_stops;
|
||||||
std::unique_ptr<class common_chat_parser> handler;
|
std::unique_ptr<class common_chat_parser> parser;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
|
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#include "chat-handler.hpp"
|
#include "chat-handler.hpp"
|
||||||
|
#include "chat-template.hpp"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
|
|
||||||
|
@ -71,9 +72,7 @@ static std::string dump(const json & j) {
|
||||||
return minja::Value(j).dump(-1, /* to_json= */ true);
|
return minja::Value(j).dump(-1, /* to_json= */ true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_parse_tool_call(common_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) {
|
||||||
std::cout << "# Testing: " << input << std::endl << std::flush;
|
|
||||||
auto result = parse_tool_calls(style, tools, input);
|
|
||||||
assert_equals(expected_content, result.content);
|
assert_equals(expected_content, result.content);
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
for (const auto & tc : result.tool_calls) {
|
for (const auto & tc : result.tool_calls) {
|
||||||
|
@ -140,209 +139,216 @@ const auto code_interpreter_tool = json::parse(R"({
|
||||||
})");
|
})");
|
||||||
const json tools = {special_function_tool, code_interpreter_tool};
|
const json tools = {special_function_tool, code_interpreter_tool};
|
||||||
|
|
||||||
static void test_parsing() {
|
// static void test_parsing() {
|
||||||
json request = {
|
// json request = {
|
||||||
{"tools", tools}
|
// {"tools", tools}
|
||||||
};
|
// };
|
||||||
|
|
||||||
const auto fooBarCall = json {
|
// const auto fooBarCall = json {
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "foo"},
|
// {"name", "foo"},
|
||||||
{"arguments", dump({
|
// {"arguments", dump({
|
||||||
{"bar", 1}
|
// {"bar", 1}
|
||||||
})},
|
// })},
|
||||||
}}
|
// }}
|
||||||
};
|
// };
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
||||||
"{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
|
// "{\"tool_call\": {\"name\": \"foo\", \"arguments\": {\"bar\": 1}}}",
|
||||||
"",
|
// "",
|
||||||
json::array({fooBarCall}));
|
// json::array({fooBarCall}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_GENERIC, tools,
|
||||||
"{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
|
// "{\"tool_calls\": [{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}]}",
|
||||||
"",
|
// "",
|
||||||
json::array({fooBarCall}));
|
// json::array({fooBarCall}));
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_HERMES_2_PRO, tools,
|
||||||
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
// "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
||||||
"",
|
// "",
|
||||||
json::array({fooBarCall}));
|
// json::array({fooBarCall}));
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
||||||
">>>python\n{\"code\": \"print('Hello, world!')\"}",
|
// ">>>python\n{\"code\": \"print('Hello, world!')\"}",
|
||||||
"",
|
// "",
|
||||||
json {{
|
// json {{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "python"},
|
// {"name", "python"},
|
||||||
{"arguments", dump({
|
// {"arguments", dump({
|
||||||
{"code", "print('Hello, world!')"}
|
// {"code", "print('Hello, world!')"}
|
||||||
})}
|
// })}
|
||||||
}}
|
// }}
|
||||||
}});
|
// }});
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3, tools,
|
||||||
">>>special_function\n{\"arg1\": 1}\n ",
|
// ">>>special_function\n{\"arg1\": 1}\n ",
|
||||||
"",
|
// "",
|
||||||
json {{
|
// json {{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "special_function"},
|
// {"name", "special_function"},
|
||||||
{"arguments", dump({
|
// {"arguments", dump({
|
||||||
{"arg1", 1}
|
// {"arg1", 1}
|
||||||
})}
|
// })}
|
||||||
}}
|
// }}
|
||||||
}});
|
// }});
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
||||||
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
// "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
||||||
"Hello, world!",
|
// "Hello, world!",
|
||||||
json {
|
// json {
|
||||||
{
|
// {
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "foo"},
|
// {"name", "foo"},
|
||||||
{"arguments", dump({
|
// {"arguments", dump({
|
||||||
{"arg1", 1}
|
// {"arg1", 1}
|
||||||
})}
|
// })}
|
||||||
}}
|
// }}
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "bar"},
|
// {"name", "bar"},
|
||||||
{"arguments", dump({
|
// {"arguments", dump({
|
||||||
{"arg2", 2}
|
// {"arg2", 2}
|
||||||
})}
|
// })}
|
||||||
}}
|
// }}
|
||||||
},
|
// },
|
||||||
});
|
// });
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1, tools,
|
||||||
"<function=test>{ } </function> ",
|
// "<function=test>{ } </function> ",
|
||||||
" ",
|
// " ",
|
||||||
json {{
|
// json {{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "test"},
|
// {"name", "test"},
|
||||||
{"arguments", "{}"}
|
// {"arguments", "{}"}
|
||||||
}}
|
// }}
|
||||||
}});
|
// }});
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"<|python_tag|>this could be anything",
|
// "<|python_tag|>this could be anything",
|
||||||
"",
|
// "",
|
||||||
json {{
|
// json {{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "python"},
|
// {"name", "python"},
|
||||||
{"arguments", "this could be anything"},
|
// {"arguments", "this could be anything"},
|
||||||
}}
|
// }}
|
||||||
}});
|
// }});
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"I'm thinking<|python_tag|>",
|
// "I'm thinking<|python_tag|>",
|
||||||
"I'm thinking",
|
// "I'm thinking",
|
||||||
json {{
|
// json {{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"name", "python"},
|
// {"name", "python"},
|
||||||
{"arguments", ""},
|
// {"arguments", ""},
|
||||||
}}
|
// }}
|
||||||
}});
|
// }});
|
||||||
auto special_function_call = json {
|
// auto special_function_call = json {
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"arguments", dump({{"arg1", 1}})},
|
// {"arguments", dump({{"arg1", 1}})},
|
||||||
{"name", "special_function"},
|
// {"name", "special_function"},
|
||||||
}},
|
// }},
|
||||||
};
|
// };
|
||||||
auto special_function_call_with_id = json::parse(special_function_call.dump());
|
// auto special_function_call_with_id = json::parse(special_function_call.dump());
|
||||||
special_function_call_with_id["id"] = "123456789";
|
// special_function_call_with_id["id"] = "123456789";
|
||||||
|
|
||||||
auto no_function_call = json::array();
|
// auto no_function_call = json::array();
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
|
// "{\"name\": \"python\", \"parameters\": {\"code\": \"print('Hey')\"}}",
|
||||||
"",
|
// "",
|
||||||
json::array({{
|
// json::array({{
|
||||||
{"type", "function"},
|
// {"type", "function"},
|
||||||
{"function", {
|
// {"function", {
|
||||||
{"arguments", dump({{"code", "print('Hey')"}})},
|
// {"arguments", dump({{"code", "print('Hey')"}})},
|
||||||
{"name", "python"},
|
// {"name", "python"},
|
||||||
}}
|
// }}
|
||||||
}}));
|
// }}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
// "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
// "",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
// "",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
// "{\n\t\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
// "",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
// "{\n \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
// "",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
// "{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||||
"",
|
// "",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
|
|
||||||
// No match: function unknown
|
// // No match: function unknown
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
no_function_call);
|
// no_function_call);
|
||||||
// No match: bad indentation
|
// // No match: bad indentation
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\n\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
no_function_call);
|
// no_function_call);
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
|
||||||
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
"{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
// "{\n \"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||||
no_function_call);
|
// no_function_call);
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO, tools,
|
||||||
"Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
// "Bleh[TOOL_CALLS][{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\", \"id\": \"123456789\"}]",
|
||||||
"Bleh",
|
// "Bleh",
|
||||||
json::array({special_function_call_with_id}));
|
// json::array({special_function_call_with_id}));
|
||||||
|
|
||||||
test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
|
// test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2, tools,
|
||||||
"Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
|
// "Bleh functools[{\"arguments\": {\"arg1\": 1}, \"name\": \"special_function\"}]",
|
||||||
"Bleh",
|
// "Bleh",
|
||||||
json::array({special_function_call}));
|
// json::array({special_function_call}));
|
||||||
}
|
// }
|
||||||
|
|
||||||
static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) {
|
// static void test_tool_call_style(const std::string & template_file, common_tool_call_style expected) {
|
||||||
const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
|
// const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
|
||||||
auto tool_call_style = common_tool_call_style_detect(tmpl);
|
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||||
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
|
// std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
|
||||||
assert_equals(expected, tool_call_style);
|
// assert_equals(expected, tool_call_style);
|
||||||
}
|
// }
|
||||||
|
|
||||||
static void test_tool_call_style_detection() {
|
// static void test_tool_call_style_detection() {
|
||||||
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1);
|
// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1);
|
||||||
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3);
|
// test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3);
|
||||||
test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2);
|
// test_tool_call_style("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2);
|
||||||
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1);
|
// test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_1);
|
||||||
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2);
|
// test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_LLAMA_3_2);
|
||||||
test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
// test_tool_call_style("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
||||||
test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
||||||
test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
// test_tool_call_style("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", COMMON_TOOL_CALL_STYLE_HERMES_2_PRO);
|
||||||
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS);
|
// test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS);
|
||||||
test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO);
|
// test_tool_call_style("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO);
|
||||||
test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC);
|
// test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", COMMON_TOOL_CALL_STYLE_GENERIC);
|
||||||
}
|
// }
|
||||||
|
|
||||||
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
|
||||||
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
|
||||||
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
|
||||||
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
|
|
||||||
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
|
common_chat_params params;
|
||||||
|
params.parallel_tool_calls = true;
|
||||||
|
params.messages = json::array();
|
||||||
|
params.messages.push_back(user_message);
|
||||||
|
params.tools = tools;
|
||||||
|
std::string prefix = common_chat_init(tmpl, params).prompt;
|
||||||
|
params.messages.push_back(delta_message);
|
||||||
|
std::string full = common_chat_init(tmpl, params).prompt;
|
||||||
|
|
||||||
// Check full starts with prefix
|
// Check full starts with prefix
|
||||||
if (full.find(prefix) != 0) {
|
if (full.find(prefix) != 0) {
|
||||||
|
@ -364,7 +370,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
|
||||||
auto tool_call_style = common_tool_call_style_detect(tmpl);
|
// auto tool_call_style = common_tool_call_style_detect(tmpl);
|
||||||
auto & tool_calls = tool_calling_message.at("tool_calls");
|
auto & tool_calls = tool_calling_message.at("tool_calls");
|
||||||
|
|
||||||
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
|
||||||
|
@ -374,8 +380,13 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
{"content", "Hello, world!"}
|
{"content", "Hello, world!"}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto handler = common_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
|
common_chat_params params;
|
||||||
auto grammar = build_grammar(handler.grammar);
|
params.parallel_tool_calls = true;
|
||||||
|
params.messages = json {user_message, tool_calling_message};
|
||||||
|
params.tools = tools;
|
||||||
|
auto chat_data = common_chat_init(tmpl, params);
|
||||||
|
fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
|
||||||
|
auto grammar = build_grammar(chat_data.grammar);
|
||||||
if (!grammar) {
|
if (!grammar) {
|
||||||
throw std::runtime_error("Failed to build grammar");
|
throw std::runtime_error("Failed to build grammar");
|
||||||
}
|
}
|
||||||
|
@ -383,7 +394,9 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
if (!skip_grammar_test) {
|
if (!skip_grammar_test) {
|
||||||
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
|
||||||
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
|
||||||
test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
|
|
||||||
|
const auto msg = chat_data.parser->parse_final(full_delta);
|
||||||
|
assert_msg_equals(msg, "", tool_calls);
|
||||||
|
|
||||||
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
|
||||||
{"role", "assistant"},
|
{"role", "assistant"},
|
||||||
|
@ -391,7 +404,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
||||||
{"tool_calls", tool_calls}
|
{"tool_calls", tool_calls}
|
||||||
}, tools);
|
}, tools);
|
||||||
if (!match_string(content_less_delta, grammar.get())) {
|
if (!match_string(content_less_delta, grammar.get())) {
|
||||||
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + handler.grammar);
|
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -423,10 +436,10 @@ static void test_grammars() {
|
||||||
}}}
|
}}}
|
||||||
};
|
};
|
||||||
|
|
||||||
// {
|
{
|
||||||
// const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
// test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
|
||||||
// }
|
}
|
||||||
// {
|
// {
|
||||||
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
|
||||||
// assert_equals(tmpl.requires_object_arguments_, true);
|
// assert_equals(tmpl.requires_object_arguments_, true);
|
||||||
|
@ -457,14 +470,14 @@ static void test_grammars() {
|
||||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
|
||||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
// }
|
// }
|
||||||
// {
|
{
|
||||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
// }
|
}
|
||||||
// {
|
{
|
||||||
// const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
|
||||||
// }
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
|
test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue