tool-calls: add generic tool call style as default
This commit is contained in:
parent
fa8462ffd3
commit
9f5ab97756
3 changed files with 110 additions and 19 deletions
|
@ -31,7 +31,7 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
|
|||
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||
return CommandRPlus;
|
||||
} else {
|
||||
return UnknownToolCallStyle;
|
||||
return Generic;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -212,8 +212,32 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, cons
|
|||
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||
}
|
||||
|
||||
// Parses model output produced under the "generic" tool-call style.
//
// The input is expected to be a single JSON object in one of these shapes:
//   {"tool_calls": [{"name": ..., "arguments": ...}, ...]}
//   {"tool_call":  {"name": ..., "arguments": ...}}
//   {"response":   ...}
//
// Returns a llama_tool_calls whose `tool_calls` holds one entry per well-formed
// call (arguments are re-serialized with dump()), or whose `content` holds the
// response (string responses verbatim, anything else pretty-printed with an
// indent of 2).
//
// Unlike a plain json::parse, this never throws on bad input: if `input` is not
// a JSON object (e.g. generation was truncated or the output is plain text),
// the raw text is returned as `content` so the caller can still surface it.
// Tool-call entries that lack a string "name" or an "arguments" member are
// skipped instead of triggering an exception or an out-of-range access.
static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
    llama_tool_calls result;

    // Non-throwing parse: a malformed document yields a "discarded" value,
    // which is not an object and therefore takes the fallback below.
    const json data = json::parse(input, /* cb= */ nullptr, /* allow_exceptions= */ false);
    if (!data.is_object()) {
        result.content = input;
        return result;
    }

    // Appends one tool call after validating its shape. operator[] on a const
    // json with a missing key is undefined behaviour, hence the use of find().
    const auto add_tool_call = [&result](const json & tool_call) {
        if (!tool_call.is_object()) {
            return;
        }
        const auto name = tool_call.find("name");
        const auto arguments = tool_call.find("arguments");
        if (name == tool_call.end() || !name->is_string() || arguments == tool_call.end()) {
            return;
        }
        result.tool_calls.push_back({
            name->get<std::string>(),
            arguments->dump(),
        });
    };

    if (data.contains("tool_calls")) {
        const auto & tool_calls = data["tool_calls"];
        if (tool_calls.is_array()) {
            for (const auto & tool_call : tool_calls) {
                add_tool_call(tool_call);
            }
        }
    } else if (data.contains("tool_call")) {
        add_tool_call(data["tool_call"]);
    } else if (data.contains("response")) {
        const auto & response = data["response"];
        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
    }
    return result;
}
|
||||
|
||||
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
||||
switch (style) {
|
||||
case llama_tool_call_style::Generic:
|
||||
return parse_generic_tool_calls(input);
|
||||
case llama_tool_call_style::Llama31:
|
||||
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
|
||||
case llama_tool_call_style::Llama32:
|
||||
|
@ -235,11 +259,72 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
|||
bool allow_content,
|
||||
bool parallel_tool_calls,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools)
|
||||
const nlohmann::ordered_json & tools,
|
||||
const nlohmann::ordered_json & json_schema)
|
||||
{
|
||||
llama_tool_call_handler handler;
|
||||
|
||||
switch (style) {
|
||||
case llama_tool_call_style::Generic: {
|
||||
auto tool_call_schemas = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
if (tool["type"] != "function") {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
tool_call_schemas.emplace_back(json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", name},
|
||||
}},
|
||||
{"arguments", parameters},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
});
|
||||
}
|
||||
const auto tool_call = json {{"anyOf", tool_call_schemas}};
|
||||
const auto schema = json {
|
||||
{"anyOf", json::array({
|
||||
parallel_tool_calls
|
||||
? json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"tool_calls", {
|
||||
{"type", "array"},
|
||||
{"items", tool_call}
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"tool_calls"})},
|
||||
}
|
||||
: json {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"tool_call", tool_call},
|
||||
}},
|
||||
{"required", json::array({"tool_call"})},
|
||||
},
|
||||
{
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"response", json_schema.is_null()
|
||||
? json {{"type", "string"}}
|
||||
: json_schema
|
||||
},
|
||||
}},
|
||||
},
|
||||
})}
|
||||
};
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
builder.add_schema("", schema);
|
||||
});
|
||||
// TODO: add schema to system prompt.
|
||||
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||
break;
|
||||
}
|
||||
case llama_tool_call_style::Llama31:
|
||||
case llama_tool_call_style::Llama32: {
|
||||
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
enum llama_tool_call_style {
|
||||
UnknownToolCallStyle,
|
||||
Generic,
|
||||
Llama31,
|
||||
Llama32,
|
||||
FunctionaryV3Llama3,
|
||||
|
@ -44,4 +45,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
|||
bool allow_content,
|
||||
bool parallel_tool_calls,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools);
|
||||
const nlohmann::ordered_json & tools,
|
||||
const nlohmann::ordered_json & json_schema = {});
|
||||
|
|
|
@ -323,7 +323,7 @@ static json oaicompat_completion_params_parse(
|
|||
llama_params["chat_template"] = tmpl.source();
|
||||
|
||||
if (use_jinja) {
|
||||
if (has_tools && !tmpl.supports_tools()) {
|
||||
if (has_tools && tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
|
||||
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
|
||||
}
|
||||
} else if (has_tools) {
|
||||
|
@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
|
|||
llama_params["parse_tool_calls"] = true;
|
||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
||||
|
||||
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
|
||||
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
|
||||
llama_params["prompt"] = handler.prompt;
|
||||
|
||||
for (const auto & stop : handler.additional_stop_words) {
|
||||
|
@ -451,22 +451,26 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
|||
auto tools = json_value(request, "tools", json::array());
|
||||
json tool_calls;
|
||||
json message_content;
|
||||
if (json_value(request, "parse_tool_calls", false)
|
||||
&& !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
|
||||
finish_reason = "tool_calls";
|
||||
if (!parsed_tool_calls.content.empty()) {
|
||||
if (json_value(request, "parse_tool_calls", false)) {
|
||||
parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
|
||||
if (!parsed_tool_calls.tool_calls.empty()) {
|
||||
finish_reason = "tool_calls";
|
||||
if (!parsed_tool_calls.content.empty()) {
|
||||
message_content = parsed_tool_calls.content;
|
||||
}
|
||||
tool_calls = json::array();
|
||||
for (const auto & tc : parsed_tool_calls.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
message_content = parsed_tool_calls.content;
|
||||
}
|
||||
tool_calls = json::array();
|
||||
for (const auto & tc : parsed_tool_calls.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
message_content = content;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue