From 36ed106f84f37f2a71167a3845171fff7bed052f Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 24 Jan 2025 02:31:37 +0000 Subject: [PATCH] WIP chat handlers --- Makefile | 3 +- common/CMakeLists.txt | 3 +- common/chat-handler.cpp | 720 +++++++++++++++++ common/chat-handler.hpp | 43 + common/chat-template.hpp | 111 ++- common/common.h | 11 +- common/sampling.cpp | 2 +- common/tool-call.cpp | 747 ------------------ common/tool-call.h | 44 -- examples/server/server.cpp | 162 ++-- examples/server/utils.hpp | 46 +- tests/CMakeLists.txt | 2 +- ...st-tool-call.cpp => test-chat-handler.cpp} | 176 +++-- 13 files changed, 1071 insertions(+), 999 deletions(-) create mode 100644 common/chat-handler.cpp create mode 100644 common/chat-handler.hpp delete mode 100644 common/tool-call.cpp delete mode 100644 common/tool-call.h rename tests/{test-tool-call.cpp => test-chat-handler.cpp} (74%) diff --git a/Makefile b/Makefile index e9a093cbb..ed04dc176 100644 --- a/Makefile +++ b/Makefile @@ -1363,10 +1363,11 @@ llama-server: \ examples/server/httplib.h \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ + common/chat-handler.cpp \ + common/chat-handler.hpp \ common/chat-template.hpp \ common/json.hpp \ common/minja.hpp \ - common/tool-call.h \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index bea32bfbe..0cfc8b3d0 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-handler.cpp + chat-handler.hpp chat-template.hpp common.cpp common.h @@ -72,7 +74,6 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h - tool-call.cpp ) if (BUILD_SHARED_LIBS) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp new file mode 100644 index 000000000..0c0aba5e9 --- /dev/null +++ b/common/chat-handler.cpp @@ -0,0 +1,720 @@ +#include "chat-handler.hpp" +#include "chat-template.hpp" +#include "json-schema-to-grammar.h" +#include "minja.hpp" + +const common_grammar_options grammar_options { + /* .dotall = */ false, + /* .compact_spaces = */ false, + // /* .compact_spaces = */ true, +}; + +static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + this->position = position - 1; + this->found_error = true; + return false; + } + bool null() override { return true; } + bool boolean(bool) override { return true; } + bool number_integer(number_integer_t) override { return true; } + bool number_unsigned(number_unsigned_t) override { return true; } + bool number_float(number_float_t, const string_t &) override { return true; } + bool string(string_t &) override { return true; } + bool binary(binary_t &) override { return true; } + bool start_object(std::size_t) override { return true; } + bool key(string_t &) override { return true; } + bool end_object() override { return true; } + bool start_array(std::size_t) override { return true; } + bool end_array() override { return true; } + }; + json_error_locator err_loc; + 
json::sax_parse(it, end, &err_loc);
+
+    std::string::const_iterator temptative_end;
+    if (err_loc.found_error) {
+        temptative_end = it + err_loc.position;
+    } else {
+        temptative_end = end;
+    }
+    std::string json_sub {it, temptative_end};
+    try {
+        out = json::parse(json_sub);
+        it = temptative_end;
+        return true;
+    } catch (const std::exception &) {
+        return false;
+    }
+}
+
+/**
+ * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
+ * Aggregates the prefix, suffix and in-between text into the content.
+ */
+static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
+    std::smatch match;
+
+    common_chat_msg result;
+    result.role = "assistant";
+    auto end = input.end();
+    auto it = input.begin();
+
+    std::vector<std::string> tool_names;
+    if (check_names) {
+        for (const auto & tool : tools) {
+            if (!tool.contains("type")) {
+                continue;
+            }
+            std::string type = tool.at("type");
+            if (type == "function") {
+                tool_names.push_back(tool["function"]["name"]);
+            } else if (type == "code_interpreter") {
+                tool_names.push_back("ipython");
+            }
+        }
+    }
+
+    while (it != end) {
+        std::sregex_iterator rend;
+        std::sregex_iterator rit(it, end, function_regex);
+        if (rit == rend) {
+            fprintf(stderr, "No more tool calls found\n");
+            result.content += std::string(it, end);
+            break;
+        }
+        auto name = rit->str(1);
+        if (check_names && std::find(tool_names.begin(), tool_names.end(), name) == tool_names.end()) {
+            fprintf(stderr, "Skipping unknown tool name: %s (known tools: %s)\n", name.c_str(), string_join(tool_names, ", ").c_str());
+            result.content += std::string(it, rit->suffix().first);
+            break;
+        }
+
+        result.content += std::string(it, rit->prefix().second);
+        it = rit->suffix().first;
+
+        json arguments;
+        if (!parse_json(it, end, arguments)) {
+            throw std::runtime_error("Failed to parse json tool call arguments");
+        }
+        if (!std::regex_search(it, end, match, close_regex)) {
+            throw std::runtime_error("Malformed input, missing closing pattern");
+        }
+        it = match.suffix().first;
+        result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
+    }
+    return result;
+}
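+
+// Illustration of parse_json_tool_calls (hypothetical tool, not from a real trace): with a
+// function_regex that captures the name, e.g. R"(\{"name": "([^"]+)", "parameters": )", and a
+// close_regex of R"(\})", the input
+//   {"name": "get_time", "parameters": {"tz": "UTC"}}
+// yields one "get_time" tool call with arguments {"tz": "UTC"}, while any text outside the
+// matched calls is accumulated into result.content.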
tool_call["id"] : "", + }); + } + }; + if (content_end == std::string::npos) { + result.content = input; + } else { + tc_start = content_end + prefix.size() - rstrip_prefix; + result.content = input.substr(0, content_end); + auto tool_calls = json::parse(input.substr(tc_start)); + process_tool_calls(tool_calls); + } + return result; +} + +static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { + json messages_with_system = messages; + + if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + std::string existing_system = messages_with_system.at(0).at("content"); + messages_with_system[0] = json { + {"role", "system"}, + {"content", existing_system + "\n" + system_prompt}, + }; + } else { + messages_with_system.insert(messages_with_system.begin(), json { + {"role", "system"}, + {"content", system_prompt}, + }); + } + return messages_with_system; +} + +class text_chat_parser : public common_chat_parser { +public: + std::optional parse_partial(const std::string & input) override { + return parse_final(input); + } + + common_chat_msg parse_final(const std::string & input) override { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; + } +}; + +class monolithic_chat_parser : public common_chat_parser { + + std::string input_buffer_; + std::function parse_final_; + +public: + monolithic_chat_parser(const std::function & parse_final) : parse_final_(parse_final) {} + + std::optional parse_partial(const std::string & input) override { + input_buffer_ += input; + return std::nullopt; + } + + common_chat_msg parse_final(const std::string & input) override { + input_buffer_ += input; + auto out = parse_final_(input_buffer_); + input_buffer_.clear(); + return out; + } +}; + +static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { + common_chat_data data; + + auto tool_call_schemas = json::array(); + for (const auto & tool : params.tools) { + const auto & function = tool["function"]; + auto tool_schema = json { + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments"})}, + }; + if (function.contains("description")) { + tool_schema["description"] = function["description"]; + } + if (params.parallel_tool_calls) { + tool_schema["properties"]["id"] = { + {"type", "string"}, + {"minLength", 4}, + }; + tool_schema["required"].push_back("id"); + } + tool_call_schemas.emplace_back(tool_schema); + } + const auto tool_call = + params.parallel_tool_calls + ? json { + {"type", "object"}, + {"properties", { + {"tool_calls", { + {"type", "array"}, + {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + {"minItems", 1}, + }}, + }}, + {"required", json::array({"tool_calls"})}, + } + : json { + {"type", "object"}, + {"properties", { + {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + }}, + {"required", json::array({"tool_call"})}, + }; + const auto schema = + params.tool_choice != "required" + ? json { + {"anyOf", json::array({ + tool_call, + { + {"type", "object"}, + {"properties", { + {"response", params.json_schema.is_null() + ? 
json {{"type", "string"}} + : params.json_schema + }, + }}, + {"required", json::array({"response"})}, + }, + })} + } + : tool_call; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_schema("root", schema); + }, grammar_options); + + // TODO: add schema to system prompt. + auto tweaked_messages = add_system( + params.messages, + "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```"); + + data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); + data.handler = std::make_unique([&](const std::string & input) { + json data = json::parse(input); + common_chat_msg result; + result.role = "assistant"; + if (data.contains("tool_calls")) { + for (const auto & tool_call : data["tool_calls"]) { + result.tool_calls.push_back({ + tool_call["name"], + tool_call["arguments"].dump(), + tool_call.contains("id") ? tool_call["id"] : "", + }); + } + } else if (data.contains("tool_call")) { + result.tool_calls.push_back({ + data["tool_call"]["name"], + data["tool_call"]["arguments"].dump(), + /* id= */ "", + }); + } else if (data.contains("response")) { + const auto & response = data["response"]; + result.content = response.is_string() ? response.get() : response.dump(2); + } + return result; + }); + return data; +} + +static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { + common_chat_data data; + auto builtin_tools = json {"wolfram_alpha", "brave_search"}; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + for (const auto & tool : params.tools) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + // Important note: the model is probably trained to take a JSON stringified arguments value. + // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + {"id", { + {"type", "string"}, + // Nemo's template expects a 9-character alphanumeric ID. + {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + } + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!params.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); + } + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, /* add_generation_prompt= */ true); + data.handler = std::make_unique([](const std::string & input) -> common_chat_msg { + return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); + }); + return data; +} + +static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) { + auto builtin_tools = json {"wolfram_alpha", "brave_search"}; + for (const auto & tool : params.tools) { + if (!tool.contains("type")) { + continue; + } + if (tool["type"] == "code_interpreter") { + builtin_tools.push_back("code_interpreter"); + break; + } + } + + common_chat_data data; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + + auto has_python = false; + + for (const auto & tool : params.tools) { + if (!tool.contains("type")) { + continue; + } + + if (tool["type"] == "code_interpreter") { + has_python = true; + } else if (tool["type"] == "function" && tool.contains("function")) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + if (uses_python_tag && (name == "python" || name == "ipython" || builtin_tools.contains(name))) { + has_python = true; + } else { + //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + if (params.tool_choice != "required" && !eagerly_match_any_json) { + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ false}); + // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. + // Note that c++11's regex doesn't support partial matches, otherwise it would make + // sense to add support for trigger regexes to the antiprompt mechanism. + data.grammar_triggers.push_back({"{\n\t\"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\n \"name\": \"" + name + "\"", /* .at_start = */ false}); + data.grammar_triggers.push_back({"{\"type\": \"function\", \"name\": \"" + name + "\"", /* .at_start = */ false}); + } + } + } + } + + if (has_python) { + tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + } + + if (params.tool_choice != "required" && eagerly_match_any_json) { + data.grammar_triggers.push_back({"{\"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\n\t\"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\n \"", /* .at_start = */ true}); + } + + builder.add_rule("root", string_join(tool_rules, " | ")); + }, grammar_options); + data.additional_stops.push_back("<|eom_id|>"); + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, /* add_generation_prompt= */ true); + data.handler = std::make_unique([params, uses_python_tag](const std::string & input) -> common_chat_msg { + if (uses_python_tag) { + static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); + std::smatch match; + if (std::regex_search(input, match, python_tag_regex)) { + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ "python", + /* .arguments = */ match[1].str(), + /* .id = */ "", + }, + } + }; + } + } + static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); + static std::regex close_regex("\\}"); + return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true); + }); + return data; +} + +static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { + common_chat_data data; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + for (const auto & tool : params.tools) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + } + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!params.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); + } + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true); + data.handler = std::make_unique([](const std::string & input) -> common_chat_msg { + return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); + }); + return data; +} + +static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) { + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... 
as trigger words for the grammar + common_chat_data data; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector first_tool_rules; + std::vector subsequent_tool_rules; + auto has_python = false; + for (const auto & tool : params.tools) { + if (!tool.contains("type")) { + continue; + } + if (tool["type"] == "code_interpreter") { + has_python = true; + } else if (tool["type"] == "function" && tool.contains("function")) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule)); + if (params.tool_choice != "required") { + data.grammar_triggers.push_back({name + "\n", /* .at_start = */ true}); + data.grammar_triggers.push_back({"\n>>>" + name + "\n", /* .at_start = */ false}); + } + } + } + auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + // Note: if there's a python rule, it needs to come last. + auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*"); + if (has_python && params.tool_choice != "required") { + data.grammar_triggers.push_back({"python\n", /* .at_start = */ true}); + data.grammar_triggers.push_back({"\n>>>python\n", /* .at_start = */ false}); + } + if (params.parallel_tool_calls) { + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; + builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : "")); + } else { + builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : "")); + } + }, grammar_options); + + data.prompt = tmpl.apply(params.messages, params.tools.empty() ? 
json() : params.tools, /* add_generation_prompt= */ true);
+    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+        static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
+        static std::regex close_regex(R"($|(?=>>>))");
+        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
+    });
+    return data;
+}
+
+static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
+    // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
+    // TODO: handle tool {type: code_interpreter} as python
+    common_chat_data data;
+
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        std::vector<std::string> tool_rules;
+        auto has_python = false;
+        for (const auto & tool : params.tools) {
+            if (!tool.contains("type")) {
+                continue;
+            }
+            if (tool["type"] == "code_interpreter") {
+                has_python = true;
+            } else if (tool["type"] == "function" && tool.contains("function")) {
+                const auto & function = tool["function"];
+                std::string name = function["name"];
+                if (name == "python" || name == "ipython") {
+                    has_python = true;
+                } else {
+                    auto parameters = function["parameters"];
+                    tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+                }
+            }
+        }
+        if (has_python) {
+            tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+            if (params.tool_choice != "required") {
+                data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+            }
+        }
+        auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
+        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        if (params.tool_choice != "required") {
+            data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
+        }
+    }, grammar_options);
+
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+        // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
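+        // e.g. "<|python_tag|>print('hi')" (illustrative input) is parsed below into a single
+        // "python" tool call whose arguments are the raw code string.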
+        static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
+        std::smatch match;
+        if (std::regex_search(input, match, python_tag_regex)) {
+            return {
+                /* .role = */ "assistant",
+                /* .content = */ match.prefix().str(),
+                /* .tool_calls = */ {
+                    {
+                        /* .name = */ "python",
+                        /* .arguments = */ match[1].str(),
+                        /* .id = */ "",
+                    },
+                }
+            };
+        }
+        static std::regex function_regex(R"(<function=(\w+)>)");
+        static std::regex close_regex(R"(</function>)");
+        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false);
+    });
+    return data;
+}
+
+static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    common_chat_data data;
+    // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        std::vector<std::string> tool_rules;
+        for (const auto & tool : params.tools) {
+            const auto & function = tool["function"];
+            std::string name = function["name"];
+            auto parameters = function["parameters"];
+            builder.resolve_refs(parameters);
+            tool_rules.push_back(builder.add_schema(name + "-call", {
+                {"type", "object"},
+                {"properties", json {
+                    {"name", json {{"const", name}}},
+                    {"arguments", parameters},
+                }},
+                {"required", json::array({"name", "arguments"})},
+            }));
+        }
+
+        auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
+        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        if (params.tool_choice != "required") {
+            data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
+        }
+    }, grammar_options);
+
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
+        try {
+            std::regex start_pattern(R"([\n\s]*<tool_call>)");
+            std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
+            std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
+
+            auto end = input.end();
+            std::sregex_iterator rend;
+            std::sregex_iterator rit(input.begin(), end, start_pattern);
+            if (rit == rend) {
+                return {"assistant", input, {}};
+            }
+
+            common_chat_msg result;
+            result.role = "assistant";
+            result.content = rit->prefix();
+
+            auto it = rit->suffix().first;
+            while (it != end) {
+                json call;
+                if (!parse_json(it, end, call)) {
+                    throw std::runtime_error("Failed to parse json tool call");
+                }
+                result.tool_calls.push_back({
+                    call["name"],
+                    call["arguments"].dump(),
+                    /* id= */ "",
+                });
+                rit = {it, end, middle_pattern};
+                if (rit != rend) {
+                    it = rit->suffix().first;
+                } else {
+                    rit = {it, end, end_pattern};
+                    if (rit == rend) {
+                        throw std::runtime_error("Malformed input, missing </tool_call>");
+                    }
+                    break;
+                }
+            }
+            return result;
+        } catch (const std::exception & e) {
+            return {"assistant", input, {}};
+        }
+    });
+    return data;
+}
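+
+// Dispatch below sniffs distinctive markers in the template source, e.g. "<tool_call>" for
+// Hermes 2 Pro or "[TOOL_CALLS]" for Mistral Nemo; the generic JSON handler is the fallback
+// when no marker matches.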
+common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    if (params.tools.is_null()) {
+        common_chat_data data;
+        data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+        data.handler = std::make_unique<text_chat_parser>();
+        return data;
+    }
+    const auto & src = tmpl.source();
+
+    if (src.find("<tool_call>") != std::string::npos) {
+        return build_hermes_2_pro_tool_call_handler(tmpl, params);
+    }
+    if (src.find(">>>all") != std::string::npos) {
+        return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
+    }
+    if (src.find("<|start_header_id|>") != std::string::npos
+        && src.find("ipython<|end_header_id|>") != std::string::npos) {
+        auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
+
+        // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
+        // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
+        // as it seems to be outputting some JSON.
+        // TODO: make this conditional on a very small model (e.g. 1B / 3B).
+        auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
+
+        return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
+    }
+    // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
+    //     TODO: Command-R-Plus
+    // }
+    if (src.find("[TOOL_CALLS]") != std::string::npos) {
+        return build_mistral_nemo_tool_call_handler(tmpl, params);
+    }
+    if (src.find(" functools[") != std::string::npos) {
+        return build_firefunction_v2_tool_call_handler(tmpl, params);
+    }
+    return build_generic_tool_call_handler(tmpl, params);
+}
diff --git a/common/chat-handler.hpp b/common/chat-handler.hpp
new file mode 100644
index 000000000..91304ab7e
--- /dev/null
+++ b/common/chat-handler.hpp
@@ -0,0 +1,43 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "common.h"
+#include <json.hpp>
+#include <memory>
+#include <optional>
+
+using json = nlohmann::ordered_json;
+
+struct common_chat_params {
+    json messages;
+    json tools;
+    json tool_choice;
+    json json_schema;
+    bool parallel_tool_calls;
+    bool stream;
+};
+
+class common_chat_parser {
+public:
+    virtual ~common_chat_parser() = default;
+
+    virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
+    virtual common_chat_msg parse_final(const std::string & input) = 0;
+};
+
+struct common_chat_data {
+    std::string prompt;
+    std::string grammar;
+    std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string> additional_stops;
+    std::unique_ptr<common_chat_parser> handler;
+};
+
+struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 42ee0b615..05f093159 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -20,7 +20,7 @@ namespace minja {
 
 class chat_template {
   public:
-  private:
+//   private:
    bool supports_tools_ = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
@@ -28,6 +28,7 @@ class chat_template { bool requires_typed_content_ = false; bool supports_system_role_ = true; bool supports_parallel_tool_calls_ = false; + bool supports_code_interpreter_ = false; std::string source_; std::string bos_token_; std::string eos_token_; @@ -60,8 +61,29 @@ class chat_template { }); supports_tools_ = source.find("tools") != std::string::npos; - auto renders_string_arguments = + requires_object_arguments_ = try_raw_render({ + { + {"role", "user"}, + {"content", "Hey"} + }, + { + {"role", "assistant"}, + {"tool_calls", json::array({ + { + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"arguments", { + {"code", "print('Hello, World!')"}, + }}, + {"name", "ipython"}, + }}, + }, + })}, + } + }, {}, false).find("{\"code\": \"print") != std::string::npos + && try_raw_render({ { {"role", "user"}, {"content", "Hey"} @@ -79,32 +101,8 @@ class chat_template { }, })}, } - }, {}, false).find("{\"code\": \"print") != std::string::npos; - if (!renders_string_arguments) { - auto renders_object_arguments = - try_raw_render({ - { - {"role", "user"}, - {"content", "Hey"} - }, - { - {"role", "assistant"}, - {"tool_calls", json::array({ - { - {"id", "call_1___"}, - {"type", "function"}, - {"function", { - {"arguments", { - {"code", "print('Hello, World!')"}, - }}, - {"name", "ipython"}, - }}, - }, - })}, - } - }, {}, false).find("{\"code\": \"print") != std::string::npos; - requires_object_arguments_ = renders_object_arguments; - } + }, {}, false).find("{\"code\": \"print") == std::string::npos; + supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; supports_system_role_ = try_raw_render({ @@ -114,6 +112,8 @@ class chat_template { requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos; + + supports_code_interpreter_ = source.find("code_interpreter") != std::string::npos; } const std::string & source() const { return source_; } @@ -130,8 +130,45 @@ class chat_template { bool adjust_inputs = true) const { json actual_messages; + json actual_tools; - // First, "fix" messages so they have a chance to be rendered correctly by the template + auto has_code_interpreter = false; + for (const auto & tool : tools) { + if (tool.contains("type") && tool.at("type") == "code_interpreter") { + has_code_interpreter = true; + break; + } + } + + if (adjust_inputs && !tools.is_null() && !supports_code_interpreter_ && has_code_interpreter) { + actual_tools = json::array(); + for (const auto & tool : tools) { + if (tool.contains("type") && tool.at("type") == "code_interpreter") { + static const auto python_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "ipython", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." 
+ } + }, + "required": ["code"] + } + } + })"); + actual_tools.push_back(python_tool); + } else { + actual_tools.push_back(tool); + } + } + } else if (!tools.is_null()) { + actual_tools = tools; + } if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) { actual_messages = json::array(); @@ -173,7 +210,12 @@ class chat_template { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); std::string arguments = function.at("arguments"); - function["arguments"] = json::parse(arguments); + try { + function["arguments"] = json::parse(arguments); + } catch (const std::exception & ecvt) { + fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); + function["arguments"] = arguments; + } } } } @@ -242,6 +284,9 @@ class chat_template { } else { actual_messages = messages; } + // if (adjust_inputs) { + // fprintf(stderr, "Messages: %s\n", actual_messages.dump(2).c_str()); + // } auto context = minja::Context::make(json({ {"messages", actual_messages}, @@ -251,8 +296,12 @@ class chat_template { })); if (!tools.is_null()) { - auto tools_val = minja::Value(tools); + auto tools_val = minja::Value(actual_tools); context->set("tools", tools_val); + if (has_code_interpreter) { + auto builtin_tools_val = minja::Value(json {"code_interpreter"}); + context->set("builtin_tools", builtin_tools_val); + } } if (!extra_context.is_null()) { for (auto & kv : extra_context.items()) { diff --git a/common/common.h b/common/common.h index 96e23689e..e075d39dd 100644 --- a/common/common.h +++ b/common/common.h @@ -109,6 +109,11 @@ enum common_conversation_mode { COMMON_CONVERSATION_MODE_AUTO = 2, }; +struct common_grammar_trigger { + std::string word; + bool at_start; +}; + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -154,9 +159,9 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TEMPERATURE, }; - std::string grammar; // optional BNF-like grammar to constrain sampling - std::vector grammar_trigger_words; // optional trigger words to enable grammar - std::vector grammar_trigger_tokens; // optional trigger tokens to enable grammar + std::string grammar; // optional BNF-like grammar to constrain sampling + std::vector grammar_trigger_words; // optional trigger words to enable grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to enable grammar std::vector logit_bias; // logit biases to apply diff --git a/common/sampling.cpp b/common/sampling.cpp index 573c61d8c..08ecb4599 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co std::vector trigger_words; trigger_words.reserve(params.grammar_trigger_words.size()); for (const auto & str : params.grammar_trigger_words) { - trigger_words.push_back(str.c_str()); + trigger_words.push_back(str.word.c_str()); } auto * result = new common_sampler { /* .params = */ params, diff --git a/common/tool-call.cpp b/common/tool-call.cpp deleted file mode 100644 index 01fce7e10..000000000 --- a/common/tool-call.cpp +++ /dev/null @@ -1,747 +0,0 @@ -#include "tool-call.h" -#include "json-schema-to-grammar.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -static json normalize_tools(const json & tools) { - static const auto python_tool = json::parse(R"({ - "type": "function", - "function": 
{ - "name": "python", - "description": "Runs code in an Python interpreter and returns the result of the execution after 60 seconds.", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The code to run in the Python interpreter." - } - }, - "required": ["code"] - } - } - })"); - - auto results = json::array(); - for (const auto & tool : tools) { - if (!tool.contains("type")) { - continue; - } - if (tool["type"] == "code_interpreter") { - results.push_back(python_tool); - } else if (tool["type"] == "function") { - results.push_back(tool); - } else { - continue; - } - } - return results; -} - -std::string common_tool_call_style_name(common_tool_call_style style) { - switch (style) { - case COMMON_TOOL_CALL_STYLE_NONE: - return "None"; - case COMMON_TOOL_CALL_STYLE_GENERIC: - return "Generic"; - case COMMON_TOOL_CALL_STYLE_LLAMA_3_1: - return "Llama-3.1"; - case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: - return "Llama-3.2"; - case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: - return "FunctionaryV3Llama3"; - case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: - return "FunctionaryV3Llama3.1"; - case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: - return "Hermes2Pro"; - case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS: - return "CommandRPlus"; - case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: - return "MistralNemo"; - case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: - return "FirefunctionV2"; - default: - return "Unknown"; - } -} - -common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) { - const auto & src = chat_template.source(); - - if (src.find("") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO; - } else if (src.find(">>>all") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3; - } else if (src.find("<|start_header_id|>") != std::string::npos - && src.find("ipython<|end_header_id|>") != std::string::npos) { - if (src.find("<|python_tag|>") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_LLAMA_3_1; - } else { - return COMMON_TOOL_CALL_STYLE_LLAMA_3_2; - } - } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS; - } else if (src.find("[TOOL_CALLS]") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO; - } else if (src.find(" functools[") != std::string::npos) { - return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2; - } else { - return COMMON_TOOL_CALL_STYLE_GENERIC; - } -} - -static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { - // // https://json.nlohmann.me/features/parsing/sax_interface/ - struct json_error_locator : public nlohmann::json_sax { - std::size_t position; - bool found_error; - - json_error_locator() : position(0), found_error(false) {} - - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { - this->position = position - 1; - this->found_error = true; - return false; - } - bool null() override { return true; } - bool boolean(bool) override { return true; } - bool number_integer(number_integer_t) override { return true; } - bool number_unsigned(number_unsigned_t) override { return true; } - bool number_float(number_float_t, const string_t &) override { return true; } - bool string(string_t &) override { return true; } - bool binary(binary_t &) override { return true; } - bool start_object(std::size_t) override { return true; } - bool key(string_t &) override 
{ return true; } - bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } - bool end_array() override { return true; } - }; - json_error_locator err_loc; - json::sax_parse(it, end, &err_loc); - - std::string::const_iterator temptative_end; - if (err_loc.found_error) { - temptative_end = it + err_loc.position; - } else { - temptative_end = end; - } - std::string json_sub {it, temptative_end}; - try { - out = json::parse(json_sub); - it = temptative_end; - return true; - } catch (const std::exception &) { - return false; - } -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { - std::smatch match; - - common_chat_msg result; - result.role = "assistant"; - auto end = input.end(); - auto it = input.begin(); - - std::unordered_set tool_names; - if (check_names) { - for (const auto & tool : tools) { - if (!tool.contains("type")) { - continue; - } - std::string type = tool.at("type"); - if (type == "function") { - tool_names.insert(tool["function"]["name"]); - } else if (type == "code_interpreter") { - tool_names.insert("python"); - } - } - } - - while (it != end) { - std::sregex_iterator rend; - std::sregex_iterator rit(it, end, function_regex); - if (rit == rend) { - result.content += std::string(it, end); - break; - } - auto name = rit->str(1); - if (check_names && tool_names.find(name) == tool_names.end()) { - result.content += std::string(it, rit->suffix().first); - break; - } - - result.content += std::string(it, rit->prefix().second); - it = rit->suffix().first; - - - json arguments; - if (!parse_json(it, end, arguments)) { - throw std::runtime_error("Failed to parse json tool call arguments"); - } - if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern"); - } - it = match.suffix().first; - result.tool_calls.push_back({name, arguments.dump(), /* id= */ ""}); - } - return result; -} - -static common_chat_msg parse_hermes_tool_calls(const std::string& input) { - try { - std::regex start_pattern(R"([\n\s]*)"); - std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); - std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); - - auto end = input.end(); - std::sregex_iterator rend; - std::sregex_iterator rit(input.begin(), end, start_pattern); - if (rit == rend) { - return {"assistant", input, {}}; - } - - common_chat_msg result; - result.role = "assistant"; - result.content = rit->prefix(); - - auto it = rit->suffix().first; - while (it != end) { - json call; - if (!parse_json(it, end, call)) { - throw std::runtime_error("Failed to parse json tool call"); - } - result.tool_calls.push_back({ - call["name"], - call["arguments"].dump(), - /* id= */ "", - }); - rit = {it, end, middle_pattern}; - if (rit != rend) { - it = rit->suffix().first; - } else { - rit = {it, end, end_pattern}; - if (rit == rend) { - throw std::runtime_error("Malformed input, missing "); - } - break; - } - } - return result; - } catch (const std::exception & e) { - return {"assistant", input, {}}; - } -} - -static common_chat_msg parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { - if (allow_python_tag) { - static std::regex 
python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); - std::smatch match; - if (std::regex_search(input, match, python_tag_regex)) { - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ (json {{"code", match[1].str()}}).dump(), - /* .id = */ "", - }, - } - }; - } - } - static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); - static std::regex close_regex("\\}"); - return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); -} - -static common_chat_msg parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); - std::smatch match; - if (std::regex_search(input, match, python_tag_regex)) { - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ (json {{"code", match[1].str()}}).dump(), - /* .id = */ "", - }, - } - }; - } - static std::regex function_regex(R"()"); - static std::regex close_regex(R"()"); - return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); -} - -static common_chat_msg parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { - static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); - static std::regex close_regex(R"($|(?=>>>))"); - return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); -} - -static common_chat_msg parse_generic_tool_calls(const std::string& input) { - json data = json::parse(input); - common_chat_msg result; - result.role = "assistant"; - if (data.contains("tool_calls")) { - for (const auto & tool_call : data["tool_calls"]) { - result.tool_calls.push_back({ - tool_call["name"], - tool_call["arguments"].dump(), - tool_call.contains("id") ? tool_call["id"] : "", - }); - } - } else if (data.contains("tool_call")) { - result.tool_calls.push_back({ - data["tool_call"]["name"], - data["tool_call"]["arguments"].dump(), - /* id= */ "", - }); - } else if (data.contains("response")) { - const auto & response = data["response"]; - result.content = response.is_string() ? response.get() : response.dump(2); - } - return result; -} - -static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { - auto content_end = input.find(prefix); - size_t tc_start = std::string::npos; - - common_chat_msg result; - result.role = "assistant"; - const auto process_tool_calls = [&](const json & tool_calls) { - for (const auto & tool_call : tool_calls) { - const auto & arguments = tool_call["arguments"]; - result.tool_calls.push_back({ - tool_call["name"], - arguments.is_string() ? arguments.get() : arguments.dump(), - tool_call.contains("id") ? 
tool_call["id"] : "", - }); - } - }; - if (content_end == std::string::npos) { - result.content = input; - } else { - tc_start = content_end + prefix.size() - rstrip_prefix; - result.content = input.substr(0, content_end); - auto tool_calls = json::parse(input.substr(tc_start)); - process_tool_calls(tool_calls); - } - return result; -} - -static common_chat_msg parse_mistral_nemo_tool_calls(const std::string& input) { - return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); -} - -static common_chat_msg parse_firefunction_v2_tool_calls(const std::string& input) { - return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); -} - -common_chat_msg parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) { - fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str()); - switch (style) { - case COMMON_TOOL_CALL_STYLE_NONE: - return {"assistant", input, {}}; - case COMMON_TOOL_CALL_STYLE_GENERIC: - return parse_generic_tool_calls(input); - case COMMON_TOOL_CALL_STYLE_LLAMA_3_1: - return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true); - case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: - return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false); - case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: - return parse_functionary_v3_tool_calls(tools, input); - case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: - return parse_functionary_v3_llama_3_1_tool_calls(tools, input); - case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: - return parse_hermes_tool_calls(input); - case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: - return parse_mistral_nemo_tool_calls(input); - case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: - return parse_firefunction_v2_tool_calls(input); - default: - throw std::runtime_error("Unsupported tool call style"); - } -} - -static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { - json messages_with_system = messages; - - if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { - std::string existing_system = messages_with_system.at(0).at("content"); - messages_with_system[0] = json { - {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, - }; - } else { - messages_with_system.insert(messages_with_system.begin(), json { - {"role", "system"}, - {"content", system_prompt}, - }); - } - return messages_with_system; -} - -common_tool_call_handler common_tool_call_handler_init( - common_tool_call_style style, - const common_chat_template & tmpl, - bool allow_content, - const nlohmann::ordered_json & parallel_tool_calls, - const nlohmann::ordered_json & messages, - const nlohmann::ordered_json & tools, - const nlohmann::ordered_json & json_schema) -{ - common_grammar_options grammar_options { - /* .dotall = */ false, - /* .compact_spaces = */ true, - }; - common_tool_call_handler handler; - auto parallel = parallel_tool_calls.is_null() ? 
tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get(); - - switch (style) { - case COMMON_TOOL_CALL_STYLE_NONE: - handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); - break; - case COMMON_TOOL_CALL_STYLE_GENERIC: { - auto actual_tools = normalize_tools(tools); - auto tool_call_schemas = json::array(); - for (const auto & tool : actual_tools) { - const auto & function = tool["function"]; - auto tool_schema = json { - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function["name"]}, - }}, - {"arguments", function["parameters"]}, - }}, - {"required", json::array({"name", "arguments"})}, - }; - if (function.contains("description")) { - tool_schema["description"] = function["description"]; - } - if (parallel) { - tool_schema["properties"]["id"] = { - {"type", "string"}, - {"minLength", 4}, - }; - tool_schema["required"].push_back("id"); - } - tool_call_schemas.emplace_back(tool_schema); - } - const auto tool_call = - parallel - ? json { - {"type", "object"}, - {"properties", { - {"tool_calls", { - {"type", "array"}, - {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - {"minItems", 1}, - }}, - }}, - {"required", json::array({"tool_calls"})}, - } - : json { - {"type", "object"}, - {"properties", { - {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - }}, - {"required", json::array({"tool_call"})}, - }; - const auto schema = - allow_content - ? json { - {"anyOf", json::array({ - tool_call, - { - {"type", "object"}, - {"properties", { - {"response", json_schema.is_null() - ? json {{"type", "string"}} - : json_schema - }, - }}, - {"required", json::array({"response"})}, - }, - })} - } - : tool_call; - handler.grammar = build_grammar([&](const common_grammar_builder & builder) { - builder.add_schema("root", schema); - }, grammar_options); - // TODO: add schema to system prompt. - auto tweaked_messages = add_system( - messages, - "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```"); - handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); - break; - } - case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: { - auto actual_tools = normalize_tools(tools); - handler.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - for (const auto & tool : actual_tools) { - const auto & function = tool["function"]; - schemas.push_back({ - {"type", "object"}, - {"properties", { - // Important note: the model is probably trained to take a JSON stringified arguments value. - // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. - {"name", { - {"type", "string"}, - {"const", function["name"]}, - }}, - {"arguments", function["parameters"]}, - {"id", { - {"type", "string"}, - // Nemo's template expects a 9-character alphanumeric ID. - {"pattern", "^[a-zA-Z0-9]{9}$"}, - }}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - } - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? 
schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!parallel) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); - }, grammar_options); - if (allow_content) { - handler.grammar_triggers.push_back("[TOOL_CALLS]"); - } - handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); - break; - } - case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: { - auto actual_tools = normalize_tools(tools); - handler.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - for (const auto & tool : actual_tools) { - const auto & function = tool["function"]; - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function["name"]}, - }}, - {"arguments", function["parameters"]}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - } - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!parallel) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); - }, grammar_options); - if (allow_content) { - handler.grammar_triggers.push_back(" functools["); - } - handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); - break; - } - case COMMON_TOOL_CALL_STYLE_LLAMA_3_1: - case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: { - auto builtin_tools = json {"wolfram_alpha", "brave_search"}; - for (const auto & tool : tools) { - if (!tool.contains("type")) { - continue; - } - if (tool["type"] == "code_interpreter") { - builtin_tools.push_back("code_interpreter"); - break; - } - } - auto actual_tools = normalize_tools(tools); - - auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1; - - // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, - // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon - // as it seems to be outputting some JSON. - // TODO: make this conditional on a very small model (e.g. 1B / 3B). - auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2; - - handler.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - - for (const auto & tool : actual_tools) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - builder.resolve_refs(parameters); - if (uses_python_tag && (name == "ipython" || builtin_tools.contains(name))) { - tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); - if (allow_content) { - handler.grammar_triggers.push_back("<|python_tag|>"); - } - } else { - //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + - " \"}\"")); - if (allow_content && !eagerly_match_any_json) { - handler.grammar_triggers.push_back("{\"name\": \"" + name + "\""); - // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. 
- // Note that c++11's regex doesn't support partial matches, otherwise it would make - // sense to add support for trigger regexes to the antiprompt mechanism. - handler.grammar_triggers.push_back("{\n\t\"name\": \"" + name + "\""); - handler.grammar_triggers.push_back("{\n \"name\": \"" + name + "\""); - handler.grammar_triggers.push_back("{\n \"name\": \"" + name + "\""); - handler.grammar_triggers.push_back("{\"type\": \"function\", \"name\": \"" + name + "\""); - } - } - } - - if (allow_content && eagerly_match_any_json) { - handler.grammar_triggers.push_back("{\""); - handler.grammar_triggers.push_back("{\n\t\""); - handler.grammar_triggers.push_back("{\n \""); - handler.grammar_triggers.push_back("{\n \""); - } - - builder.add_rule("root", string_join(tool_rules, " | ")); - }, grammar_options); - handler.additional_stops.push_back("<|eom_id|>"); - handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, { - {"builtin_tools", builtin_tools}, - }); - break; - } - case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: { - // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... - // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar - auto actual_tools = normalize_tools(tools); - handler.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector first_tool_rules; - std::vector subsequent_tool_rules; - for (const auto & tool : actual_tools) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\"\\n>>>" + name + "\\n\" " + args_rule)); - if (allow_content) { - handler.grammar_triggers.push_back(name + "\n"); - handler.grammar_triggers.push_back("\n>>>" + name + "\n"); - } - } - auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; - if (parallel) { - auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); - } else { - builder.add_rule("root", first_rule); - } - }, grammar_options); - handler.prompt = tmpl.apply(messages, actual_tools.empty() ? 
-            // handler.parser = parse_functionary_3_2_tool_calls;
-            break;
-        }
-        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
-            // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
-            // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
-            // TODO: handle tool {type: code_interpreter} as python
-            auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
-                std::vector<std::string> tool_rules;
-                for (const auto & tool : actual_tools) {
-                    const auto & function = tool["function"];
-                    std::string name = function["name"];
-                    auto parameters = function["parameters"];
-                    if (name == "python" || name == "ipython") {
-                        tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
-                        if (allow_content) {
-                            handler.grammar_triggers.push_back("<|python_tag|>");
-                        }
-                    } else {
-                        tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
-                    }
-                }
-                auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
-                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
-                if (allow_content) {
-                    handler.grammar_triggers.push_back("<function=");
-                }
-            }, grammar_options);
-            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
-            break;
-        }
-        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
-            // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-            auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
-                std::vector<std::string> tool_rules;
-                for (const auto & tool : actual_tools) {
-                    const auto & function = tool["function"];
-                    std::string name = function["name"];
-                    auto parameters = function["parameters"];
-                    builder.resolve_refs(parameters);
-                    tool_rules.push_back(builder.add_schema(name + "-call", {
-                        {"type", "object"},
-                        {"properties", json {
-                            {"name", json {{"const", name}}},
-                            {"arguments", parameters},
-                        }},
-                        {"required", json::array({"name", "arguments"})},
-                    }));
-                }
-
-                auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
-                builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
-                if (allow_content) {
-                    handler.grammar_triggers.push_back("<tool_call>");
-                }
-            }, grammar_options);
-            handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
-            break;
-        }
-        default:
-            throw std::runtime_error("Unsupported tool call style");
-    }
-    return handler;
-}
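
All of the deleted handlers above drive the same build_grammar / common_grammar_builder machinery from json-schema-to-grammar.h. As a minimal, self-contained sketch of that API (the tool name and argument schema are made up for illustration, and the helpers are assumed to behave as in the calls above), a Hermes-2-Pro-style grammar for a single tool could be assembled like this:

    #include "json-schema-to-grammar.h"
    #include "json.hpp"

    using json = nlohmann::ordered_json;

    // Sketch: constrain output to one <tool_call>{...}</tool_call> block for a
    // hypothetical "get_weather" tool; appending "+" to the root rule would
    // allow parallel calls, mirroring the Hermes handler above.
    static std::string example_hermes_grammar() {
        return build_grammar([&](const common_grammar_builder & builder) {
            auto call_rule = builder.add_schema("get_weather-call", json {
                {"type", "object"},
                {"properties", json {
                    {"name", json {{"const", "get_weather"}}},
                    {"arguments", json {
                        {"type", "object"},
                        {"properties", json {{"location", json {{"type", "string"}}}}},
                        {"required", json::array({"location"})},
                    }},
                }},
                {"required", json::array({"name", "arguments"})},
            });
            builder.add_rule("root", "\"<tool_call>\" space " + call_rule + " \"</tool_call>\" space");
        });
    }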
diff --git a/common/tool-call.h b/common/tool-call.h
deleted file mode 100644
index 37b5d9739..000000000
--- a/common/tool-call.h
+++ /dev/null
@@ -1,44 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "common.h"
-#include "chat-template.hpp"
-
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
-#include "json.hpp"
-
-enum common_tool_call_style {
-    COMMON_TOOL_CALL_STYLE_UNKNOWN,
-    COMMON_TOOL_CALL_STYLE_NONE,
-    COMMON_TOOL_CALL_STYLE_GENERIC,
-    COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
-    COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
-    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
-    COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
-    COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
-    COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
-    COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
-    COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
-};
-
-struct common_tool_call_handler {
-    std::string prompt;
-    std::string grammar;
-    std::vector<std::string> grammar_triggers;
-    std::vector<std::string> additional_stops;
-};
-
-std::string common_tool_call_style_name(common_tool_call_style style);
-
-common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);
-
-common_chat_msg parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
-
-common_tool_call_handler common_tool_call_handler_init(
-    common_tool_call_style style,
-    const common_chat_template & tmpl,
-    bool allow_content,
-    const nlohmann::ordered_json & parallel_tool_calls,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    const nlohmann::ordered_json & json_schema = {});
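
For reference, the styles enumerated in the deleted header correspond to these model output shapes, as encoded by the grammar rules and trigger words above (argument payloads are illustrative):

    Mistral Nemo:      [TOOL_CALLS][{"name": "special_function", "arguments": {"arg1": 1}, "id": "123456789"}]
    FireFunction v2:    functools[{"name": "special_function", "arguments": {"arg1": 1}}]
    Llama 3.x:         {"name": "special_function", "parameters": {"arg1": 1}}   (or <|python_tag|>... for builtin tools)
    Functionary v3.2:  >>>special_function\n{"arg1": 1}
    Functionary v3.1:  <function=special_function>{"arg1": 1}</function>
    Hermes 2 Pro:      <tool_call>{"name": "special_function", "arguments": {"arg1": 1}}</tool_call>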
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 939e6c36a..5e34fd9ee 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -117,8 +117,7 @@ struct slot_params {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    json oaicompat_tools;
-    common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
+    std::shared_ptr<common_chat_parser> chat_parser;

     json to_json() const {
         std::vector<std::string> samplers;
@@ -166,7 +165,7 @@ struct slot_params {
             {"n_probs",                sampling.n_probs},
             {"min_keep",               sampling.min_keep},
             {"grammar",                sampling.grammar},
-            {"grammar_trigger_words",  sampling.grammar_trigger_words},
+            // {"grammar_trigger_words",  sampling.grammar_trigger_words},
             {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
             {"samplers",               samplers},
             {"speculative.n_max",      speculative.n_max},
@@ -212,6 +211,7 @@ struct server_task {
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
            const common_chat_template * tmpl,
             const json & data) {
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -322,6 +322,41 @@
             }
         }

+        {
+            params.antiprompt.clear();
+            const auto stop = data.find("stop");
+            if (stop != data.end()) {
+                params.antiprompt = *stop;
+            }
+        }
+
+        if (tmpl && params_base.use_jinja) {
+            common_chat_params chat_params;
+            chat_params.messages = json_value(data, "messages", json::array());
+            chat_params.tools = json_value(data, "tools", json());
+            chat_params.tool_choice = json_value(data, "tool_choice", std::string("auto"));
+            chat_params.json_schema = json_value(data, "json_schema", json());
+            chat_params.parallel_tool_calls = json_value(data, "parallel_tool_calls", false);
+            chat_params.stream = json_value(data, "stream", false);
+
+            auto chat_data = common_chat_init(*tmpl, chat_params);
+            params.chat_parser = std::move(chat_data.handler);
+            params.sampling.grammar = chat_data.grammar;
+            for (const auto & stop : chat_data.additional_stops) {
+                params.antiprompt.push_back(stop);
+            }
+            for (const auto & trigger : chat_data.grammar_triggers) {
+                auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
+                if (ids.size() == 1) {
+                    LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                    params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                    continue;
+                }
+                LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+                params.sampling.grammar_trigger_words.push_back(trigger);
+            }
+        }
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
@@ -336,13 +371,7 @@ struct server_task {
         } else {
             params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
         }
-
-        if (data.contains("tools")) {
-            params.oaicompat_tools = data.at("tools");
-        }
-        if (data.contains("tool_call_style")) {
-            params.oaicompat_tool_call_style = data.at("tool_call_style");
-        }
+        LOG_INF("Grammar: %s\n", params.sampling.grammar.c_str());

         {
             params.sampling.logit_bias.clear();
@@ -379,45 +408,11 @@ struct server_task {
             }
         }

-        auto to_string_vec = [](const json & j) {
-            std::vector<std::string> out;
-            if (j.is_array()) {
-                for (const auto & e : j) {
-                    if (e.is_string()) {
-                        out.push_back(e);
-                    }
-                }
-            }
-            return out;
-        };
-
-        {
-            params.antiprompt.clear();
-            const auto stop = data.find("stop");
-            if (stop != data.end()) {
-                params.antiprompt = to_string_vec(*stop);
-            }
-        }
-
-        {
-            const auto grammar_trigger_words = data.find("grammar_trigger_words");
-            if (grammar_trigger_words != data.end()) {
-                for (const auto & word : to_string_vec(*grammar_trigger_words)) {
-                    auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
-                    if (ids.size() == 1) {
-                        params.sampling.grammar_trigger_tokens.push_back(ids[0]);
-                        continue;
-                    }
-                    params.sampling.grammar_trigger_words.push_back(word);
-                }
-            }
-        }
-
         {
             const auto samplers = data.find("samplers");
             if (samplers != data.end()) {
                 if (samplers->is_array()) {
-                    params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
+                    params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
                 } else if (samplers->is_string()){
                     params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
                 }
@@ -592,8 +587,7 @@ struct server_task_result_cmpl_final : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    json oaicompat_tools;
-    common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
+    common_chat_msg oaicompat_chat_msg;

     virtual int get_index() override {
         return index;
@@ -688,39 +682,29 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = "stop";
+            finish_reason = oaicompat_chat_msg.tool_calls.empty() ? "stop" : "tool_calls";
"stop" : "tool_calls"; } json tool_calls; - json message_content; - if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) { - auto parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); - if (!parsed_tool_calls.tool_calls.empty()) { - finish_reason = "tool_calls"; - message_content = parsed_tool_calls.content; - tool_calls = json::array(); - for (const auto & tc : parsed_tool_calls.tool_calls) { - tool_calls.push_back({ - {"type", "function"}, - {"function", { - {"name", tc.name}, - {"arguments", tc.arguments}, - }}, - {"id", tc.id.empty() ? json() : json(tc.id)}, - }); - } - } else { - message_content = parsed_tool_calls.content; + if (!oaicompat_chat_msg.tool_calls.empty()) { + tool_calls = json::array(); + for (const auto & tc : oaicompat_chat_msg.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id.empty() ? json() : json(tc.id)}, + }); } - } else { - message_content = content; } json choice { {"finish_reason", finish_reason}, {"index", 0}, {"message", json { - {"content", message_content}, + {"content", oaicompat_chat_msg.content}, {"tool_calls", tool_calls}, {"role", "assistant"}, }}, @@ -812,6 +796,8 @@ struct server_task_result_cmpl_partial : server_task_result { oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_chat_msg; + std::shared_ptr chat_parser; virtual int get_index() override { return index; @@ -1234,6 +1220,8 @@ struct server_slot { std::string stopping_word; + std::shared_ptr chat_parser; + // sampling json json_schema; @@ -2260,6 +2248,10 @@ struct server_context { } void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send); + if (!opt_msg) { + return; + } auto res = std::make_unique(); res->id = slot.id_task; @@ -2275,6 +2267,7 @@ struct server_context { res->oaicompat = slot.params.oaicompat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_msg = *opt_msg; // populate res.probs_output if (slot.params.sampling.n_probs > 0) { @@ -2315,8 +2308,7 @@ struct server_context { res->oaicompat = slot.params.oaicompat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_tools = slot.params.oaicompat_tools; - res->oaicompat_tool_call_style = slot.params.oaicompat_tool_call_style; + res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text); // populate res.probs_output if (slot.params.sampling.n_probs > 0) { @@ -3776,7 +3768,7 @@ int main(int argc, char ** argv) { std::function is_connection_closed, httplib::Response & res, oaicompat_type oaicompat, - common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) { + const common_chat_template * tmpl) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); if (ctx_server.params_base.embedding) { @@ -3788,6 +3780,20 @@ int main(int argc, char ** argv) { std::vector tasks; try { + fprintf(stderr, "PROMPT: %s\n", data.at("prompt").get().c_str()); + std::string prompt; + if (tmpl && ctx_server.params_base.use_jinja) { + auto chat_data = common_chat_init(*tmpl, { + /* .messages = */ json_data(data, "messages", 
+                    /* .tools = */ json_value(data, "tools", json()),
+                    /
+                });
+
+                prompt = ctx_server.chat_templates.template_default->render(data.at("prompt").get<std::string>());
+            } else {
+                prompt = data.at("prompt").get<std::string>();
+            }
+            task.params.chat_parser = common_chat_init()
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -3800,12 +3806,14 @@
                 task.params = server_task::params_from_json_cmpl(
                         ctx_server.ctx,
                         ctx_server.params_base,
+                        nullptr,
                         data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);

                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
+                task.params.chat_parser = common_chat_init()
                 task.params.oaicompat_tools = json_value(data, "tools", json());
                 task.params.oaicompat_tool_call_style = tool_call_style;
                 // oaicompat_model is already populated by params_from_json_cmpl
@@ -3983,18 +3991,16 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
-        auto tool_call_style = common_tool_call_style_detect(chat_template);
-        LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str());
+        LOG_INF("Request: %s\n", body.dump(2).c_str());

-        json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja);
+        json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);

         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
             req.is_connection_closed,
             res,
-            OAICOMPAT_TYPE_CHAT,
-            tool_call_style);
+            OAICOMPAT_TYPE_CHAT);
     };

     const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
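
The server changes above route every streamed chunk through chat_parser->parse_partial, which may withhold output entirely (the early return in send_partial_response), and hand the full generation to parse_final at the end. A minimal sketch of that contract follows; the class name and the hold-back heuristic are illustrative, not the actual common/chat-handler.hpp interface, and common_chat_msg is assumed as declared in common.h:

    #include <optional>
    #include <string>

    struct example_streaming_parser {
        std::string buffer;

        // Fed each decoded chunk; returning nullopt tells the server to hold
        // the chunk back (e.g. while a Hermes-style tool call is still open).
        std::optional<common_chat_msg> parse_partial(const std::string & chunk) {
            buffer += chunk;
            auto open  = buffer.rfind("<tool_call>");
            auto close = buffer.rfind("</tool_call>");
            if (open != std::string::npos && (close == std::string::npos || close < open)) {
                return std::nullopt; // don't stream half a tool call
            }
            common_chat_msg msg;
            msg.role    = "assistant";
            msg.content = buffer;
            return msg;
        }

        // Called once on the complete generation to extract content + tool calls.
        common_chat_msg parse_final(const std::string & full_text);
    };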
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 1869ae7ab..b6e4e1def 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -17,8 +17,8 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 #include "minja.hpp"
+#include "chat-handler.hpp"
 #include "chat-template.hpp"
-#include "tool-call.h"

 #include <random>
 #include <sstream>
@@ -581,24 +581,18 @@ static json oaicompat_completion_params_parse(const json & body) {
 static json oaicompat_completion_params_parse(
     const json & body, /* openai api json semantics */
     const common_chat_template & tmpl,
-    common_tool_call_style tool_call_style,
     bool use_jinja)
 {
     json llama_params;

     auto tools = json_value(body, "tools", json());
-    auto has_tools = tools.is_array() && !tools.empty();
     auto stream = json_value(body, "stream", false);

-    if (has_tools) {
+    if (tools.is_array() && !tools.empty()) {
         if (stream) {
             throw std::runtime_error("Cannot use tools with stream");
         }
-        if (use_jinja) {
-            if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
-                throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
-            }
-        } else {
+        if (!use_jinja) {
             throw std::runtime_error("tools param requires --jinja flag");
         }
     }
@@ -627,31 +621,15 @@ static json oaicompat_completion_params_parse(

     // Apply chat template to the list of messages
     if (use_jinja) {
-        bool allow_content = tool_choice != "required";
-        if (tool_choice != "none" && has_tools) {
-            llama_params["tools"] = tools;
-            llama_params["tool_call_style"] = tool_call_style;
-
-            auto parallel_tool_calls = body.contains("parallel_tool_calls") ? body.at("parallel_tool_calls") : json();
-            llama_params["parallel_tool_calls"] = parallel_tool_calls;
-
-            auto handler = common_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools, llama_params["json_schema"]);
-            llama_params["prompt"] = handler.prompt;
-
-            for (const auto & stop : handler.additional_stops) {
-                llama_params["stop"].push_back(stop);
-            }
-            if (!handler.grammar_triggers.empty()) {
-                llama_params["grammar_trigger_words"] = handler.grammar_triggers;
-            }
-            if (!handler.grammar.empty()) {
-                if (llama_params.contains("grammar")) {
-                    throw std::runtime_error("Cannot use custom grammar constraints with tools.");
-                }
-                llama_params["grammar"] = handler.grammar;
-            }
-        } else {
-            llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+        llama_params["tools"] = tools;
+        auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
+        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
+            throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+        }
+        llama_params["tool_choice"] = tool_choice;
+        llama_params["parallel_tool_calls"] = json_value(body, "parallel_tool_calls", false);
+        if (tool_choice != "none" && llama_params.contains("grammar")) {
+            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
         }
     } else {
         llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index b1c43da98..61833292f 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -133,7 +133,7 @@ llama_target_and_test(test-chat-template.cpp)
 # llama_target_and_test(test-opt.cpp) # SLOW
 llama_target_and_test(test-gguf.cpp)
 llama_target_and_test(test-backend-ops.cpp)

-llama_target_and_test(test-tool-call.cpp)
+llama_target_and_test(test-chat-handler.cpp)

 llama_target_and_test(test-model-load-cancel.cpp  LABEL "model")
 llama_target_and_test(test-autorelease.cpp        LABEL "model")
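
With oaicompat_completion_params_parse now forwarding tool metadata instead of expanding it into a prompt at parse time, a request body exercising the new path looks like this (field values are illustrative; note that "stream": true with tools still throws, per the check above):

    static const json example_request = json::parse(R"({
        "model": "any-model",
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"]
                }
            }
        }],
        "tool_choice": "auto",
        "parallel_tool_calls": false,
        "stream": false
    })");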
diff --git a/tests/test-tool-call.cpp b/tests/test-chat-handler.cpp
similarity index 74%
rename from tests/test-tool-call.cpp
rename to tests/test-chat-handler.cpp
index a10bab605..cb42e9b49 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-chat-handler.cpp
@@ -1,4 +1,4 @@
-#include "tool-call.h"
+#include "chat-handler.hpp"
 #include "llama-grammar.h"
 #include "unicode.h"

@@ -20,6 +20,7 @@ static void assert_equals(const T & expected, const T & actual) {
 }

 static std::string read_file(const std::string &path) {
+    std::cout << "# Reading: " << path << std::endl << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -76,10 +77,16 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
     assert_equals(expected_content, result.content);
     auto tool_calls = json::array();
     for (const auto & tc : result.tool_calls) {
+        auto arguments = tc.arguments;
+        try {
+            arguments = dump(json::parse(arguments));
+        } catch (const std::exception & e) {
+            // ignore
+        }
         auto tool_call = json {
             {"type", "function"},
             {"function", {
-                {"arguments", dump(json::parse(tc.arguments))},
+                {"arguments", arguments},
                 {"name", tc.name},
             }},
         };
@@ -94,42 +101,44 @@ static void test_parse_tool_call(common_tool_call_style style, const json & tool
     assert_equals(expected, actual);
 }

-const json tools = json::parse(R"([
-  {
-    "type": "function",
-    "function": {
-      "name": "special_function",
-      "description": "I'm special",
-      "parameters": {
-        "type": "object",
-        "properties": {
-          "arg1": {
-            "type": "integer",
-            "description": "The arg."
-          }
-        },
-        "required": ["arg1"]
-      }
-    }
-  },
-  {
-    "type": "function",
-    "function": {
-      "name": "python",
-      "description": "a python interpreter",
-      "parameters": {
-        "type": "object",
-        "properties": {
-          "code": {
-            "type": "string",
-            "description": "The code."
-          }
-        },
-        "required": ["code"]
-      }
-    }
-  }
-])");
+const auto special_function_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "special_function",
+    "description": "I'm special",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "arg1": {
+          "type": "integer",
+          "description": "The arg."
+        }
+      },
+      "required": ["arg1"]
+    }
+  }
+})");
+const auto python_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "python",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+const auto code_interpreter_tool = json::parse(R"({
+  "type": "code_interpreter"
+})");
+const json tools = {special_function_tool, code_interpreter_tool};

 static void test_parsing() {
     json request = {
@@ -226,9 +235,7 @@ static void test_parsing() {
           {"type", "function"},
           {"function", {
             {"name", "python"},
-            {"arguments", dump({
-              {"code", "this could be anything"}
-            })}
+            {"arguments", "this could be anything"},
           }}
         }});
     test_parse_tool_call(common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1, tools,
@@ -238,7 +245,7 @@ static void test_parsing() {
           {"type", "function"},
           {"function", {
             {"name", "python"},
-            {"arguments", dump({{"code", ""}})}
+            {"arguments", ""},
           }}
         }});
     auto special_function_call = json {
@@ -332,6 +339,8 @@ static void test_tool_call_style_detection() {
 }

 static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+    fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
+    fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
     auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
     auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());

@@ -354,9 +363,7 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
     return delta;
 }

-static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
-    std::cout << "# Testing template: " << template_file << std::endl << std::flush;
-    const common_chat_template tmpl(read_file(template_file), bos_token, eos_token);
+static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { auto tool_call_style = common_tool_call_style_detect(tmpl); auto & tool_calls = tool_calling_message.at("tool_calls"); @@ -404,24 +411,77 @@ static void test_grammars() { auto tool_call_message_with_id = json::parse(tool_call_message.dump()); tool_call_message_with_id["tool_calls"][0]["id"] = "123456789"; - test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "", "", { "" }, tool_call_message_with_id, tools, - /* skip_grammar_test= */ true); - test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "", "", { "<|im_end|>" }, tool_call_message, tools); - test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "", "", { "<|im_end|>" }, tool_call_message, tools); - test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "", "", { "<|im_end|>" }, tool_call_message, tools); - test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", "", "", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", "", "", { "<|eot_id|>" }, tool_call_message, tools); - test_template("tests/chat/templates/google-gemma-2-2b-it.jinja", "", "", { "" }, tool_call_message_with_id, tools); - test_template("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja", "", "", { "<|end|>" }, tool_call_message_with_id, tools); + auto python_tool_call_message = json { + {"role", "assistant"}, + {"content", ""}, + {"tool_calls", json {{ + {"type", "function"}, + {"function", { + {"name", "python"}, + {"arguments", "print('hey')"} + }}, + }}} + }; + + // { + // const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + // test_template(tmpl, { "" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true); + // } + // { + // const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""); + // assert_equals(tmpl.requires_object_arguments_, true); + // test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); + // test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools); + // } + // { + // const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + // test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); + // } + // { + // const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""); + // test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools); + // } + // { + // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "", ""); + // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); + // } + // { 
+    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
+    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
+    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
+    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    // }
+    // {
+    //     const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
+    //     test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+    // }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
+    }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
+    }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
+        test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
+    }
 }

 int main() {
-    test_tool_call_style_detection();
-    test_parsing();
     test_grammars();

     std::cout << "\n[tool-call] All tests passed!" << std::endl;
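
Re-enabling one of the commented templates, or covering a new one, follows the same pattern as the active blocks in test_grammars() above (the template path and end token below are hypothetical):

    {
        const common_chat_template tmpl(read_file("tests/chat/templates/my-model.jinja"), "<s>", "</s>");
        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
    }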