diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index e9b90a72c..6e784a1a9 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -31,7 +31,7 @@ llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template &
     } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
         return CommandRPlus;
     } else {
-        return UnknownToolCallStyle;
+        return Generic;
     }
 }
 
@@ -212,8 +212,32 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, cons
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }
 
+static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
+    json data = json::parse(input);
+    llama_tool_calls result;
+    if (data.contains("tool_calls")) {
+        for (const auto & tool_call : data["tool_calls"]) {
+            result.tool_calls.push_back({
+                tool_call["name"],
+                tool_call["arguments"].dump(),
+            });
+        }
+    } else if (data.contains("tool_call")) {
+        result.tool_calls.push_back({
+            data["tool_call"]["name"],
+            data["tool_call"]["arguments"].dump(),
+        });
+    } else if (data.contains("response")) {
+        const auto & response = data["response"];
+        result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
+    }
+    return result;
+}
+
 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
     switch (style) {
+        case llama_tool_call_style::Generic:
+            return parse_generic_tool_calls(input);
         case llama_tool_call_style::Llama31:
             return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
         case llama_tool_call_style::Llama32:
@@ -235,11 +259,72 @@ llama_tool_call_handler llama_tool_call_handler_init(
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools)
+    const nlohmann::ordered_json & tools,
+    const nlohmann::ordered_json & json_schema)
 {
     llama_tool_call_handler handler;
 
     switch (style) {
+        case llama_tool_call_style::Generic: {
+            auto tool_call_schemas = json::array();
+            for (const auto & tool : tools) {
+                if (tool["type"] != "function") {
+                    continue;
+                }
+                const auto & function = tool["function"];
+                std::string name = function["name"];
+                auto parameters = function["parameters"];
+                tool_call_schemas.emplace_back(json {
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", name},
+                        }},
+                        {"arguments", parameters},
+                    }},
+                    {"required", json::array({"name", "arguments"})},
+                });
+            }
+            const auto tool_call = json {{"anyOf", tool_call_schemas}};
+            const auto schema = json {
+                {"anyOf", json::array({
+                    parallel_tool_calls
+                        ? json {
+                            {"type", "object"},
+                            {"properties", {
+                                {"tool_calls", {
+                                    {"type", "array"},
+                                    {"items", tool_call}
+                                }},
+                            }},
+                            {"required", json::array({"tool_calls"})},
+                        }
+                        : json {
+                            {"type", "object"},
+                            {"properties", {
+                                {"tool_call", tool_call},
+                            }},
+                            {"required", json::array({"tool_call"})},
+                        },
+                    {
+                        {"type", "object"},
+                        {"properties", {
+                            {"response", json_schema.is_null()
+                                ? json {{"type", "string"}}
+                                : json_schema
+                            },
+                        }},
+                    },
+                })}
+            };
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                builder.add_schema("", schema);
+            });
+            // TODO: add schema to system prompt.
+            handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
+            break;
+        }
         case llama_tool_call_style::Llama31:
         case llama_tool_call_style::Llama32: {
             static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
diff --git a/common/tool-call.h b/common/tool-call.h
index dc505ba2d..b6911f22e 100644
--- a/common/tool-call.h
+++ b/common/tool-call.h
@@ -9,6 +9,7 @@
 
 enum llama_tool_call_style {
     UnknownToolCallStyle,
+    Generic,
     Llama31,
     Llama32,
     FunctionaryV3Llama3,
@@ -44,4 +45,5 @@ llama_tool_call_handler llama_tool_call_handler_init(
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools);
+    const nlohmann::ordered_json & tools,
+    const nlohmann::ordered_json & json_schema = {});
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index aff2a9554..fc66fb591 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -323,7 +323,7 @@ static json oaicompat_completion_params_parse(
     llama_params["chat_template"] = tmpl.source();
 
     if (use_jinja) {
-        if (has_tools && !tmpl.supports_tools()) {
+        if (has_tools && tool_call_style == llama_tool_call_style::UnknownToolCallStyle) {
             throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
         }
     } else if (has_tools) {
@@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
         llama_params["parse_tool_calls"] = true;
         llama_params["parallel_tool_calls"] = parallel_tool_calls;
 
-        auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
+        auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
         llama_params["prompt"] = handler.prompt;
 
         for (const auto & stop : handler.additional_stop_words) {
@@ -451,22 +451,26 @@ static json format_final_response_oaicompat(const json & request, const json & r
     auto tools = json_value(request, "tools", json::array());
     json tool_calls;
     json message_content;
-    if (json_value(request, "parse_tool_calls", false)
-        && !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
-        finish_reason = "tool_calls";
-        if (!parsed_tool_calls.content.empty()) {
+    if (json_value(request, "parse_tool_calls", false)) {
+        parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content);
+        if (!parsed_tool_calls.tool_calls.empty()) {
+            finish_reason = "tool_calls";
+            if (!parsed_tool_calls.content.empty()) {
+                message_content = parsed_tool_calls.content;
+            }
+            tool_calls = json::array();
+            for (const auto & tc : parsed_tool_calls.tool_calls) {
+                tool_calls.push_back({
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tc.name},
+                        {"arguments", tc.arguments},
+                    }}
+                });
+            }
+        } else {
             message_content = parsed_tool_calls.content;
         }
-        tool_calls = json::array();
-        for (const auto & tc : parsed_tool_calls.tool_calls) {
-            tool_calls.push_back({
-                {"type", "function"},
-                {"function", {
-                    {"name", tc.name},
-                    {"arguments", tc.arguments},
-                }}
-            });
-        }
     } else {
         message_content = content;
     }
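
A note on the expected model output (illustrative, not part of the diff): parse_generic_tool_calls() distinguishes three top-level JSON shapes, and the grammar built in the Generic handler constrains decoding to exactly those shapes. A minimal standalone sketch, assuming only nlohmann::json and a hypothetical get_weather tool, with one payload per branch:

    // Standalone sketch (not from the patch) of the three shapes
    // parse_generic_tool_calls() accepts; "get_weather" and its
    // arguments are hypothetical.
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::json;

        // Parallel form: each entry of "tool_calls" becomes one tool call.
        json parallel = json::parse(R"({"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]})");

        // Single form: exactly one call under "tool_call".
        json single = json::parse(R"({"tool_call": {"name": "get_weather", "arguments": {"city": "Paris"}}})");

        // No call: "response" is surfaced as plain message content instead.
        json answer = json::parse(R"({"response": "It is sunny in Paris."})");

        for (const auto & data : {parallel, single, answer}) {
            if (data.contains("tool_calls")) {
                std::cout << "calls:   " << data["tool_calls"].dump() << "\n";
            } else if (data.contains("tool_call")) {
                std::cout << "call:    " << data["tool_call"].dump() << "\n";
            } else if (data.contains("response")) {
                std::cout << "content: " << data["response"].dump() << "\n";
            }
        }
    }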
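
For the same hypothetical tool (one string parameter, "city"), with parallel_tool_calls == false and a null json_schema (so the "response" branch falls back to a plain string), the schema handed to builder.add_schema() works out to roughly:

    {
      "anyOf": [
        {
          "type": "object",
          "properties": {
            "tool_call": {
              "anyOf": [
                {
                  "type": "object",
                  "properties": {
                    "name": {"type": "string", "const": "get_weather"},
                    "arguments": {"type": "object", "properties": {"city": {"type": "string"}}}
                  },
                  "required": ["name", "arguments"]
                }
              ]
            }
          },
          "required": ["tool_call"]
        },
        {
          "type": "object",
          "properties": {
            "response": {"type": "string"}
          }
        }
      ]
    }

The grammar derived from this forces the model to emit either a well-formed call or a free-form "response", which is what lets Generic act as the fallback style for templates with no native tool-call markup.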
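
Downstream, format_final_response_oaicompat() reports a parsed call with finish_reason set to "tool_calls" and OpenAI-style function entries. Abridged and with illustrative values (the surrounding envelope is the usual chat-completion shape, which this diff does not touch):

    {
      "choices": [{
        "finish_reason": "tool_calls",
        "message": {
          "tool_calls": [{
            "type": "function",
            "function": {
              "name": "get_weather",
              "arguments": "{\"city\": \"Paris\"}"
            }
          }]
        }
      }]
    }

Note that "arguments" is a string of JSON rather than an object, because the parser stores tool_call["arguments"].dump().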