tool-calls: add generic tool call style as default

ochafik 2024-10-22 10:53:21 +01:00
parent fa8462ffd3
commit 9f5ab97756
3 changed files with 110 additions and 19 deletions

View file

@@ -31,7 +31,7 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
     } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
         return CommandRPlus;
     } else {
-        return UnknownToolCallStyle;
+        return Generic;
     }
 }
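For context: any template whose source contains none of the probed markers now detects as Generic instead of UnknownToolCallStyle. A minimal sketch, assuming minja::chat_template is constructed from the template source plus BOS/EOS strings as elsewhere in this branch:

    // Sketch only, not part of the commit: an unrecognized template falls back to Generic.
    minja::chat_template tmpl("{{ messages }}", /* bos_token= */ "", /* eos_token= */ "");
    auto style = llama_tool_call_style_detect(tmpl);
    // style == llama_tool_call_style::Generic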
@@ -212,8 +212,32 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, cons
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }
 
+static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
+    json data = json::parse(input);
+    llama_tool_calls result;
+    if (data.contains("tool_calls")) {
+        for (const auto & tool_call : data["tool_calls"]) {
+            result.tool_calls.push_back({
+                tool_call["name"],
+                tool_call["arguments"].dump(),
+            });
+        }
+    } else if (data.contains("tool_call")) {
+        result.tool_calls.push_back({
+            data["tool_call"]["name"],
+            data["tool_call"]["arguments"].dump(),
+        });
+    } else if (data.contains("response")) {
+        const auto & response = data["response"];
+        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
+    }
+    return result;
+}
+
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
     switch (style) {
+        case llama_tool_call_style::Generic:
+            return parse_generic_tool_calls(input);
         case llama_tool_call_style::Llama31:
             return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
         case llama_tool_call_style::Llama32:
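For illustration, the generic parser accepts three mutually exclusive top-level shapes; a hedged sketch (the tool name and arguments below are made up):

    // Shapes handled by parse_generic_tool_calls:
    //   {"tool_calls": [{"name": ..., "arguments": {...}}, ...]}   // parallel calls
    //   {"tool_call":  {"name": ..., "arguments": {...}}}          // single call
    //   {"response":   "..."}                                      // plain answer
    llama_tool_calls res = parse_generic_tool_calls(
        R"({"tool_call": {"name": "get_weather", "arguments": {"city": "Paris"}}})");
    // res.tool_calls.size() == 1; res.tool_calls[0].arguments holds the dumped JSON string.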
@@ -235,11 +259,72 @@ llama_tool_call_handler llama_tool_call_handler_init(
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools)
+    const nlohmann::ordered_json & tools,
+    const nlohmann::ordered_json & json_schema)
 {
     llama_tool_call_handler handler;
     switch (style) {
+        case llama_tool_call_style::Generic: {
+            auto tool_call_schemas = json::array();
+            for (const auto & tool : tools) {
+                if (tool["type"] != "function") {
+                    continue;
+                }
+                const auto & function = tool["function"];
+                std::string name = function["name"];
+                auto parameters = function["parameters"];
+                tool_call_schemas.emplace_back(json {
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", name},
+                        }},
+                        {"arguments", parameters},
+                    }},
+                    {"required", json::array({"name", "arguments"})},
+                });
+            }
+            const auto tool_call = json {{"anyOf", tool_call_schemas}};
+            const auto schema = json {
+                {"anyOf", json::array({
+                    parallel_tool_calls
+                        ? json {
+                            {"type", "object"},
+                            {"properties", {
+                                {"tool_calls", {
+                                    {"type", "array"},
+                                    {"items", tool_call}
+                                }},
+                            }},
+                            {"required", json::array({"tool_calls"})},
+                        }
+                        : json {
+                            {"type", "object"},
+                            {"properties", {
+                                {"tool_call", tool_call},
+                            }},
+                            {"required", json::array({"tool_call"})},
+                        },
+                    {
+                        {"type", "object"},
+                        {"properties", {
+                            {"response", json_schema.is_null()
+                                ? json {{"type", "string"}}
+                                : json_schema
+                            },
+                        }},
+                    },
+                })}
+            };
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                builder.add_schema("", schema);
+            });
+            // TODO: add schema to system prompt.
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
+            break;
+        }
         case llama_tool_call_style::Llama31:
         case llama_tool_call_style::Llama32: {
             static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
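To make the constrained output concrete, these are the top-level shapes the generated grammar admits (the payloads are made up):

    // parallel_tool_calls == true:
    //   {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}},
    //                   {"name": "get_time",    "arguments": {"tz": "CET"}}]}
    // parallel_tool_calls == false:
    //   {"tool_call": {"name": "get_weather", "arguments": {"city": "Paris"}}}
    // either way, a plain answer remains possible, typed by json_schema when one is given:
    //   {"response": "It is sunny in Paris."}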

View file

@@ -9,6 +9,7 @@
 enum llama_tool_call_style {
     UnknownToolCallStyle,
+    Generic,
     Llama31,
     Llama32,
     FunctionaryV3Llama3,
@@ -44,4 +45,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools);
+    const nlohmann::ordered_json & tools,
+    const nlohmann::ordered_json & json_schema = {});
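Because json_schema is defaulted, existing call sites compile unchanged. A hedged usage sketch (the variable names are illustrative):

    // No response schema: the generic "response" branch falls back to {"type": "string"}.
    auto h1 = llama_tool_call_handler_init(style, tmpl, allow_content, parallel_tool_calls, messages, tools);
    // Forwarding a response schema constrains the "response" branch instead:
    auto h2 = llama_tool_call_handler_init(style, tmpl, allow_content, parallel_tool_calls, messages, tools,
                                           json {{"type", "object"}});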

View file

@@ -323,7 +323,7 @@ static json oaicompat_completion_params_parse(
     llama_params["chat_template"] = tmpl.source();
 
     if (use_jinja) {
-        if (has_tools && !tmpl.supports_tools()) {
+        if (has_tools && tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
             throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
         }
     } else if (has_tools) {
@@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
         llama_params["parse_tool_calls"] = true;
         llama_params["parallel_tool_calls"] = parallel_tool_calls;
 
-        auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
+        auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
 
         llama_params["prompt"] = handler.prompt;
         for (const auto & stop : handler.additional_stop_words) {
@@ -451,22 +451,26 @@ static json format_final_response_oaicompat(const json & request, const json & r
     auto tools = json_value(request, "tools", json::array());
     json tool_calls;
     json message_content;
-    if (json_value(request, "parse_tool_calls", false)
-        && !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
-        finish_reason = "tool_calls";
-        if (!parsed_tool_calls.content.empty()) {
+    if (json_value(request, "parse_tool_calls", false)) {
+        parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
+        if (!parsed_tool_calls.tool_calls.empty()) {
+            finish_reason = "tool_calls";
+            if (!parsed_tool_calls.content.empty()) {
+                message_content = parsed_tool_calls.content;
+            }
+            tool_calls = json::array();
+            for (const auto & tc : parsed_tool_calls.tool_calls) {
+                tool_calls.push_back({
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tc.name},
+                        {"arguments", tc.arguments},
+                    }}
+                });
+            }
+        } else {
             message_content = parsed_tool_calls.content;
         }
-        tool_calls = json::array();
-        for (const auto & tc : parsed_tool_calls.tool_calls) {
-            tool_calls.push_back({
-                {"type", "function"},
-                {"function", {
-                    {"name", tc.name},
-                    {"arguments", tc.arguments},
-                }}
-            });
-        }
     } else {
         message_content = content;
     }
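For illustration (values made up): when parsing succeeds, the final OpenAI-compatible choice comes out roughly as below; note that "arguments" is the dumped JSON string, not a nested object:

    // {
    //   "finish_reason": "tool_calls",
    //   "message": {
    //     "tool_calls": [{
    //       "type": "function",
    //       "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}
    //     }]
    //   }
    // }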