tool-calls: add generic tool call style as default
This commit is contained in:
parent
fa8462ffd3
commit
9f5ab97756
3 changed files with 110 additions and 19 deletions
|
@ -31,7 +31,7 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
|
||||||
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||||
return CommandRPlus;
|
return CommandRPlus;
|
||||||
} else {
|
} else {
|
||||||
return UnknownToolCallStyle;
|
return Generic;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,8 +212,32 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, cons
|
||||||
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
|
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Parses model output produced under the generic tool-call style.
 *
 * The input must be a JSON document. Three object shapes are recognised,
 * checked in this order:
 *   - {"tool_calls": [{"name": ..., "arguments": ...}, ...]}  one entry per call
 *   - {"tool_call":  {"name": ..., "arguments": ...}}         a single call
 *   - {"response": ...}                                        plain content
 *
 * For tool calls, "arguments" is re-serialised to a JSON string. For a
 * response, a string value is used verbatim; any other JSON value is
 * pretty-printed with a 2-space indent.
 *
 * @param input Raw model output, expected to be valid JSON.
 * @return The tool calls and/or content extracted from the input. When none of
 *         the recognised keys is present the result is left default-constructed.
 * @throws json::exception if the input is not valid JSON, or if a tool call
 *         entry lacks "name" or "arguments".
 */
static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
    const json data = json::parse(input);
    llama_tool_calls result;

    // Checked access via at(): a malformed entry raises a catchable exception
    // rather than hitting the undefined behaviour of const operator[] on a
    // missing key.
    const auto add_tool_call = [&result](const json & tool_call) {
        result.tool_calls.push_back({
            tool_call.at("name"),
            tool_call.at("arguments").dump(),
        });
    };

    if (data.contains("tool_calls")) {
        for (const auto & tool_call : data.at("tool_calls")) {
            add_tool_call(tool_call);
        }
    } else if (data.contains("tool_call")) {
        add_tool_call(data.at("tool_call"));
    } else if (data.contains("response")) {
        const auto & response = data.at("response");
        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
    }
    return result;
}
|
||||||
|
|
||||||
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
||||||
switch (style) {
|
switch (style) {
|
||||||
|
case llama_tool_call_style::Generic:
|
||||||
|
return parse_generic_tool_calls(input);
|
||||||
case llama_tool_call_style::Llama31:
|
case llama_tool_call_style::Llama31:
|
||||||
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
|
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
|
||||||
case llama_tool_call_style::Llama32:
|
case llama_tool_call_style::Llama32:
|
||||||
|
@ -235,11 +259,72 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
bool parallel_tool_calls,
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools)
|
const nlohmann::ordered_json & tools,
|
||||||
|
const nlohmann::ordered_json & json_schema)
|
||||||
{
|
{
|
||||||
llama_tool_call_handler handler;
|
llama_tool_call_handler handler;
|
||||||
|
|
||||||
switch (style) {
|
switch (style) {
|
||||||
|
case llama_tool_call_style::Generic: {
|
||||||
|
auto tool_call_schemas = json::array();
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (tool["type"] != "function") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
tool_call_schemas.emplace_back(json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"name", {
|
||||||
|
{"type", "string"},
|
||||||
|
{"const", name},
|
||||||
|
}},
|
||||||
|
{"arguments", parameters},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"name", "arguments"})},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const auto tool_call = json {{"anyOf", tool_call_schemas}};
|
||||||
|
const auto schema = json {
|
||||||
|
{"anyOf", json::array({
|
||||||
|
parallel_tool_calls
|
||||||
|
? json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"tool_calls", {
|
||||||
|
{"type", "array"},
|
||||||
|
{"items", tool_call}
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"tool_calls"})},
|
||||||
|
}
|
||||||
|
: json {
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"tool_call", tool_call},
|
||||||
|
}},
|
||||||
|
{"required", json::array({"tool_call"})},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{"type", "object"},
|
||||||
|
{"properties", {
|
||||||
|
{"response", json_schema.is_null()
|
||||||
|
? json {{"type", "string"}}
|
||||||
|
: json_schema
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
};
|
||||||
|
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||||
|
builder.add_schema("", schema);
|
||||||
|
});
|
||||||
|
// TODO: add schema to system prompt.
|
||||||
|
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
case llama_tool_call_style::Llama31:
|
case llama_tool_call_style::Llama31:
|
||||||
case llama_tool_call_style::Llama32: {
|
case llama_tool_call_style::Llama32: {
|
||||||
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
enum llama_tool_call_style {
|
enum llama_tool_call_style {
|
||||||
UnknownToolCallStyle,
|
UnknownToolCallStyle,
|
||||||
|
Generic,
|
||||||
Llama31,
|
Llama31,
|
||||||
Llama32,
|
Llama32,
|
||||||
FunctionaryV3Llama3,
|
FunctionaryV3Llama3,
|
||||||
|
@ -44,4 +45,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||||
bool allow_content,
|
bool allow_content,
|
||||||
bool parallel_tool_calls,
|
bool parallel_tool_calls,
|
||||||
const nlohmann::ordered_json & messages,
|
const nlohmann::ordered_json & messages,
|
||||||
const nlohmann::ordered_json & tools);
|
const nlohmann::ordered_json & tools,
|
||||||
|
const nlohmann::ordered_json & json_schema = {});
|
||||||
|
|
|
@ -323,7 +323,7 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["chat_template"] = tmpl.source();
|
llama_params["chat_template"] = tmpl.source();
|
||||||
|
|
||||||
if (use_jinja) {
|
if (use_jinja) {
|
||||||
if (has_tools && !tmpl.supports_tools()) {
|
if (has_tools && tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
|
||||||
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
|
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
|
||||||
}
|
}
|
||||||
} else if (has_tools) {
|
} else if (has_tools) {
|
||||||
|
@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
|
||||||
llama_params["parse_tool_calls"] = true;
|
llama_params["parse_tool_calls"] = true;
|
||||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
||||||
|
|
||||||
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
|
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
|
||||||
llama_params["prompt"] = handler.prompt;
|
llama_params["prompt"] = handler.prompt;
|
||||||
|
|
||||||
for (const auto & stop : handler.additional_stop_words) {
|
for (const auto & stop : handler.additional_stop_words) {
|
||||||
|
@ -451,8 +451,9 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
||||||
auto tools = json_value(request, "tools", json::array());
|
auto tools = json_value(request, "tools", json::array());
|
||||||
json tool_calls;
|
json tool_calls;
|
||||||
json message_content;
|
json message_content;
|
||||||
if (json_value(request, "parse_tool_calls", false)
|
if (json_value(request, "parse_tool_calls", false)) {
|
||||||
&& !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
|
parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
|
||||||
|
if (!parsed_tool_calls.tool_calls.empty()) {
|
||||||
finish_reason = "tool_calls";
|
finish_reason = "tool_calls";
|
||||||
if (!parsed_tool_calls.content.empty()) {
|
if (!parsed_tool_calls.content.empty()) {
|
||||||
message_content = parsed_tool_calls.content;
|
message_content = parsed_tool_calls.content;
|
||||||
|
@ -467,6 +468,9 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
||||||
}}
|
}}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
message_content = parsed_tool_calls.content;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
message_content = content;
|
message_content = content;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue