diff --git a/Makefile b/Makefile
index 529fc6313..ef152d246 100644
--- a/Makefile
+++ b/Makefile
@@ -52,7 +52,7 @@ TEST_TARGETS = \
 	tests/test-arg-parser \
 	tests/test-autorelease \
 	tests/test-backend-ops \
-	tests/test-chat-handler \
+	tests/test-chat \
 	tests/test-chat-template \
 	tests/test-double-float \
 	tests/test-grammar-integration \
@@ -984,7 +984,7 @@ OBJ_COMMON = \
 	$(DIR_COMMON)/ngram-cache.o \
 	$(DIR_COMMON)/sampling.o \
 	$(DIR_COMMON)/speculative.o \
-	$(DIR_COMMON)/chat-handler.o \
+	$(DIR_COMMON)/chat.o \
 	$(DIR_COMMON)/build-info.o \
 	$(DIR_COMMON)/json-schema-to-grammar.o
@@ -1363,8 +1363,8 @@ llama-server: \
 	examples/server/httplib.h \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
-	common/chat-handler.cpp \
-	common/chat-handler.hpp \
+	common/chat.cpp \
+	common/chat.hpp \
 	common/chat-template.hpp \
 	common/json.hpp \
 	common/minja.hpp \
@@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
 	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

-tests/test-chat-handler: tests/test-chat-handler.cpp \
+tests/test-chat: tests/test-chat.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 0cfc8b3d0..72f0915c1 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -56,8 +56,8 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
-    chat-handler.cpp
-    chat-handler.hpp
+    chat.cpp
+    chat.hpp
     chat-template.hpp
     common.cpp
     common.h
diff --git a/common/chat-handler.cpp b/common/chat.cpp
similarity index 83%
rename from common/chat-handler.cpp
rename to common/chat.cpp
index aaef05dfd..2ed89459c 100644
--- a/common/chat-handler.cpp
+++ b/common/chat.cpp
@@ -1,4 +1,4 @@
-#include "chat-handler.hpp"
+#include "chat.hpp"
 #include "chat-template.hpp"
 #include "json-schema-to-grammar.h"
 #include "log.h"
@@ -170,11 +170,11 @@ static common_chat_msg no_op_text_parser(const std::string & input) {
     };
 }

-static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    common_chat_data data;
+static common_chat_params common_chat_params_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    common_chat_params data;

     auto tool_call_schemas = json::array();
-    foreach_function(params.tools, [&](const json & tool) {
+    foreach_function(inputs.tools, [&](const json & tool) {
         const auto & function = tool["function"];
         auto tool_schema = json {
             {"type", "object"},
@@ -190,7 +190,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
         if (function.contains("description")) {
             tool_schema["description"] = function["description"];
         }
-        if (params.parallel_tool_calls) {
+        if (inputs.parallel_tool_calls) {
             tool_schema["properties"]["id"] = {
                 {"type", "string"},
                 {"minLength", 4},
@@ -200,7 +200,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
         tool_call_schemas.emplace_back(tool_schema);
     });
     const auto tool_call =
-        params.parallel_tool_calls
+        inputs.parallel_tool_calls
            ? json {
                  {"type", "object"},
                  {"properties", {
@@ -224,16 +224,16 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
         {"required", json::array({"tool_call"})},
     };
     const auto schema =
-        params.tool_choice != "required"
+        inputs.tool_choice != "required"
            ? json {
                  {"anyOf", json::array({
                      tool_call,
                      {
                          {"type", "object"},
                          {"properties", {
-                             {"response", params.json_schema.is_null()
+                             {"response", inputs.json_schema.is_null()
                                  ? json {{"type", "string"}}
-                                 : params.json_schema
+                                 : inputs.json_schema
                              },
                          }},
                          {"required", json::array({"response"})},
@@ -248,10 +248,10 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
     }, grammar_options);

     auto tweaked_messages = common_chat_template::add_system(
-        params.messages,
+        inputs.messages,
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");

-    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = "generic tool calls";
     data.parser = [&](const std::string & input) {
         json data = json::parse(input);
@@ -280,12 +280,12 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
     return data;
 }
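For reference, the generic handler constrains output to a single JSON object carrying either a tool call or a plain response. A minimal sketch of consuming such an output with nlohmann::json (tool name and arguments below are invented for illustration; the real parsing lives in the lambda above):

    #include <iostream>
    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        // What a schema-constrained model might emit (hypothetical values):
        auto output = json::parse(R"({
            "tool_call": { "name": "get_weather", "arguments": { "city": "Paris" } }
        })");
        if (output.contains("tool_call")) {
            // Route to the tool dispatcher.
            std::cout << "call " << output["tool_call"]["name"]
                      << "(" << output["tool_call"]["arguments"].dump() << ")\n";
        } else if (output.contains("response")) {
            // Plain text answer for the user.
            std::cout << output["response"].get<std::string>() << "\n";
        }
    }

With parallel_tool_calls enabled, the schema instead admits an array of such calls, each also carrying an "id" of at least four characters per the minLength constraint added above.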
json { {"type", "object"}, {"properties", { @@ -224,16 +224,16 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem {"required", json::array({"tool_call"})}, }; const auto schema = - params.tool_choice != "required" + inputs.tool_choice != "required" ? json { {"anyOf", json::array({ tool_call, { {"type", "object"}, {"properties", { - {"response", params.json_schema.is_null() + {"response", inputs.json_schema.is_null() ? json {{"type", "string"}} - : params.json_schema + : inputs.json_schema }, }}, {"required", json::array({"response"})}, @@ -248,10 +248,10 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem }, grammar_options); auto tweaked_messages = common_chat_template::add_system( - params.messages, + inputs.messages, "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```"); - data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = "generic tool calls"; data.parser = [&](const std::string & input) { json data = json::parse(input); @@ -280,12 +280,12 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem return data; } -static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - common_chat_data data; - data.grammar_lazy = params.tool_choice != "required"; +static common_chat_params common_chat_params_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); - foreach_function(params.tools, [&](const json & tool) { + foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool["function"]; schemas.push_back({ {"type", "object"}, @@ -311,13 +311,13 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, {"minItems", 1}, }; - if (!params.parallel_tool_calls) { + if (!inputs.parallel_tool_calls) { schema["maxItems"] = 1; } builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); }, grammar_options); data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); - data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? 
@@ -344,10 +344,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }

-static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
-    common_chat_data data;
-    data.grammar_lazy = params.tool_choice != "required";
+    common_chat_params data;
+    data.grammar_lazy = inputs.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;

@@ -380,7 +380,7 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
         };

         auto has_function = false;
-        foreach_function(params.tools, [&](const json & tool) {
+        foreach_function(inputs.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             std::string name = function["name"];
             auto parameters = function["parameters"];
@@ -410,12 +410,12 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
     data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
-    data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
+    data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
         static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -447,17 +447,17 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_
                };
            }
        }
-        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
+        return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
     };
     return data;
 }
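Several handlers here pair grammar_lazy with grammar_triggers: the grammar only starts constraining sampling once a trigger substring appears in the output, optionally anchored at position zero. A self-contained sketch of that gating decision (the trigger struct is a simplified stand-in for the one these handlers populate):

    #include <string>
    #include <vector>

    struct trigger { std::string word; bool at_start; };

    // Returns true once the generated text contains a trigger, i.e. when a
    // lazy grammar should begin constraining the sampler.
    static bool grammar_should_engage(const std::string & text, const std::vector<trigger> & triggers) {
        for (const auto & t : triggers) {
            auto pos = text.find(t.word);
            if (pos == std::string::npos) continue;
            if (!t.at_start || pos == 0) return true;
        }
        return false;
    }

For example, with the Mistral Nemo trigger {"[TOOL_CALLS]", true}, an output starting "Sure, ..." never engages the grammar, while one starting "[TOOL_CALLS][" does.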
" (w/ builtin tools)" : ""); - data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg { + data.parser = [inputs, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg { static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); static std::regex close_regex("\\}"); static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); @@ -447,17 +447,17 @@ static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_ }; } } - return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); + return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true); }; return data; } -static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - common_chat_data data; - data.grammar_lazy = params.tool_choice != "required"; +static common_chat_params common_chat_params_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - foreach_function(params.tools, [&](const json & tool) { + foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; @@ -466,27 +466,27 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); }); data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space"); + builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); - data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? 
-static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_params common_chat_params_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     fprintf(stderr, "%s\n", __func__);
-    common_chat_data data;
-    if (!params.tools.is_null() && !params.tools.empty()) {
-        data.grammar_lazy = params.tool_choice != "required";
+    common_chat_params data;
+    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             auto schemas = json::array();
-            foreach_function(params.tools, [&](const json & tool) {
+            foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool["function"];
                 schemas.push_back({
                     {"type", "object"},
@@ -505,7 +505,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
                 {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                 {"minItems", 1},
             };
-            if (!params.parallel_tool_calls) {
+            if (!inputs.parallel_tool_calls) {
                 schema["maxItems"] = 1;
             }
             builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
@@ -519,24 +519,24 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
         data.parser = no_op_text_parser;
         data.format = "firefunction v2 text-only";
     }
-    data.prompt = tmpl.apply(params.messages, /* tools= */ nullptr, params.add_generation_prompt, {
+    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
-        {"functions", json(params.tools.empty() ? "" : params.tools.dump(2))},
+        {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     }, /* adjust_inputs= */ false);
     return data;
 }

-static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
-    common_chat_data data;
+    common_chat_params data;

-    if (!params.tools.is_null() && !params.tools.empty()) {
-        data.grammar_lazy = params.tool_choice != "required";
+    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
-            foreach_function(params.tools, [&](const json & tool) {
+            foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool["function"];
                 std::string name = function["name"];
                 auto parameters = function["parameters"];
@@ -547,7 +547,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
                 data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
             });
             auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
-            if (params.parallel_tool_calls) {
+            if (inputs.parallel_tool_calls) {
                 auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
                 builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
             } else {
@@ -560,12 +560,12 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
         data.format = "functionary v3.2 content-only";
     }

-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
-    data.parser = [params](const std::string & input) {
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.parser = [inputs](const std::string & input) {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
-        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
+        auto res = parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
         if (res.content.find("all\n") == 0) {
             res.content = res.content.substr(4);
         }
@@ -574,17 +574,17 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
     return data;
 }
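The function_regex above splits Functionary v3.2 output into ">>>name\n{json}" segments. A self-contained demo of that regex doing the splitting (the output string is invented):

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // Hypothetical Functionary v3.2 output with two parallel calls:
        std::string out = ">>>get_weather\n{\"city\": \"Paris\"}\n>>>get_time\n{\"tz\": \"CET\"}\n";
        std::regex fn(R"((?:>>>)?(\w+)\n)");
        for (std::sregex_iterator it(out.begin(), out.end(), fn), end; it != end; ++it) {
            std::cout << "function: " << (*it)[1] << '\n';  // get_weather, then get_time
        }
    }

Note the special ">>>all\n" channel carries plain content rather than a call, which is why the parser above strips a leading "all\n" from the parsed content.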
-static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
-    common_chat_data data;
-    json tools = params.tools.is_null() ? params.tools : json::array();
+    common_chat_params data;
+    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
     std::string python_code_argument_name;
     auto has_raw_python = false;

-    data.grammar_lazy = params.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
-        foreach_function(params.tools, [&](const json & tool) {
+        foreach_function(inputs.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             const auto & parameters = function["parameters"];
             std::string name = function["name"];
@@ -618,13 +618,13 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
             data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
         }
         auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
-        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
     }, grammar_options);

-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = "functionary v3.1 llama 3.1 tool calls";
-    data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
+    data.parser = [inputs, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -644,18 +644,18 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
         }
         static std::regex function_regex(R"(<function=(\w+)>)");
         static std::regex close_regex(R"(</function>)");
-        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
+        return parse_json_tool_calls(inputs.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
     };
     return data;
 }

-static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    common_chat_data data;
+static common_chat_params common_chat_params_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-    data.grammar_lazy = params.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
-        foreach_function(params.tools, [&](const json & tool) {
+        foreach_function(inputs.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             std::string name = function["name"];
             auto parameters = function["parameters"];
@@ -670,11 +670,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
             }));
         });
         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
-        builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
     }, grammar_options);

-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = "hermes 2 pro tool calls";
     data.parser = [&](const std::string & input) -> common_chat_msg {
         try {
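The Hermes 2 Pro format wraps each call in <tool_call>...</tool_call>; the try block that follows walks those spans in the real parser. A simplified, standalone sketch of the extraction (input string invented, error handling reduced to the no-call case):

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        const std::string open = "<tool_call>", close = "</tool_call>";
        // Hypothetical model output: some content, then one wrapped call.
        std::string input = "Let me check.<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</tool_call>";
        size_t start = input.find(open);
        if (start == std::string::npos) {
            std::cout << input << '\n';  // no tool call: plain content
            return 0;
        }
        std::string content = input.substr(0, start);            // content before the first call
        size_t body = start + open.size();
        size_t end = input.find(close, body);
        auto call = json::parse(input.substr(body, end - body));  // {"name": ..., "arguments": ...}
        std::cout << content << " -> " << call["name"] << '\n';
    }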
@@ -733,60 +733,60 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
     return data;
 }

-static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    common_chat_data data;
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    common_chat_params data;
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = "content-only";
     data.parser = no_op_text_parser;
     data.grammar_lazy = false;
-    if (!params.json_schema.is_null()) {
-        if (!params.grammar.empty()) {
+    if (!inputs.json_schema.is_null()) {
+        if (!inputs.grammar.empty()) {
             throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
         }
-        data.grammar = json_schema_to_grammar(params.json_schema);
+        data.grammar = json_schema_to_grammar(inputs.json_schema);
     } else {
-        data.grammar = params.grammar.empty();
+        data.grammar = inputs.grammar;
     }
     return data;
 }

-common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    auto has_tools = params.tools.is_null() || params.tool_choice == "none";
-    if (has_tools && !params.grammar.empty()) {
+common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
+    if (has_tools && !inputs.grammar.empty()) {
         throw std::runtime_error("Cannot specify grammar with tools");
     }

     const auto & src = tmpl.source();
     if (src.find(">>>all") != std::string::npos) {
         // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
-        return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
+        return common_chat_params_init_functionary_v3_2(tmpl, inputs);
     }
     if (src.find(" functools[") != std::string::npos) {
-        // Firefunction v2 requires datetime and functions in the context
-        return common_chat_init_firefunction_v2_tool_call(tmpl, params);
+        // Firefunction v2 requires datetime and functions in the context, even w/o tools.
+        return common_chat_params_init_firefunction_v2_tool_call(tmpl, inputs);
     }
-    if (!has_tools) {
-        return common_chat_init_without_tools(tmpl, params);
+    if (!has_tools) {
+        return common_chat_params_init_without_tools(tmpl, inputs);
     }
     if (src.find("<tool_call>") != std::string::npos) {
-        return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
+        return common_chat_params_init_hermes_2_pro_tool_call(tmpl, inputs);
     }
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
     }
     if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
-        return common_chat_init_deepseek_r1_tool_call(tmpl, params);
+        return common_chat_params_init_deepseek_r1_tool_call(tmpl, inputs);
     }
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return common_chat_init_mistral_nemo_tool_call(tmpl, params);
+        return common_chat_params_init_mistral_nemo_tool_call(tmpl, inputs);
     }
-    return common_chat_init_generic_tool_call(tmpl, params);
+    return common_chat_params_init_generic_tool_call(tmpl, inputs);
 }
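Putting the pieces together, a caller drives the renamed API roughly like this (a sketch built on the types in this diff; the template source, bos/eos tokens, and message JSON are placeholder values, and error handling is elided):

    #include "chat.hpp"
    #include "chat-template.hpp"

    void example(const std::string & template_src) {
        common_chat_template tmpl(template_src, /* bos */ "<s>", /* eos */ "</s>");

        common_chat_inputs inputs;
        inputs.messages = json::array({{
            {"role", "user"},
            {"content", "What's the weather in Paris?"},
        }});
        inputs.tools = json::array();   // tool definitions would go here
        inputs.tool_choice = "auto";
        inputs.add_generation_prompt = true;

        common_chat_params params = common_chat_params_init(tmpl, inputs);
        // params.prompt           -> formatted prompt to feed the model
        // params.grammar          -> GBNF to constrain sampling (possibly lazy)
        // params.grammar_triggers -> substrings that switch a lazy grammar on
        // params.parser           -> turns raw model output back into a common_chat_msg
    }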
diff --git a/common/chat-handler.hpp b/common/chat.hpp
similarity index 68%
rename from common/chat-handler.hpp
rename to common/chat.hpp
index 24b96706c..3ca2c54e3 100644
--- a/common/chat-handler.hpp
+++ b/common/chat.hpp
@@ -1,11 +1,5 @@
-/*
-    Copyright 2024 Google LLC
+// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
-
-    Use of this source code is governed by an MIT-style
-    license that can be found in the LICENSE file or at
-    https://opensource.org/licenses/MIT.
-*/
-// SPDX-License-Identifier: MIT
 #pragma once

 #include "common.h"
@@ -16,7 +10,7 @@

 using json = nlohmann::ordered_json;

-struct common_chat_params {
+struct common_chat_inputs {
     json messages;
     json tools;
     json tool_choice;
@@ -29,7 +23,7 @@ struct common_chat_params {

 typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;

-struct common_chat_data {
+struct common_chat_params {
     json prompt;
     std::string grammar;
     std::vector<common_grammar_trigger> grammar_triggers;
@@ -39,4 +33,4 @@ struct common_chat_data {
     bool grammar_lazy = false;
 };

-struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
+struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
diff --git a/common/common.cpp b/common/common.cpp
index 032754b8a..72b6491f1 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -12,7 +12,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "chat-handler.hpp"
+#include "chat.hpp"
 #include "chat-template.hpp"

 #include <algorithm>
@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
             auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_params params;
+            common_chat_inputs params;
             params.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});
-            common_chat_init(chat_template, params);
+            common_chat_params_init(chat_template, params);
             return true;
         } catch (const std::exception & e) {
             LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,10 +1803,10 @@ std::string common_chat_apply_template(
     for (const auto & msg : msgs) {
         messages.push_back({{"role", msg.role}, {"content", msg.content}});
     }
-    common_chat_params params;
+    common_chat_inputs params;
     params.messages = messages;
     params.add_generation_prompt = add_ass;
-    auto data = common_chat_init(tmpl, params);
+    auto data = common_chat_params_init(tmpl, params);
     return data.prompt;
 }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c5ba7c2b2..1fe79c244 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1824,16 +1824,16 @@ struct server_context {
         if (use_jinja) {
             auto templates = common_chat_templates_from_model(model, "");
-            common_chat_params params;
+            common_chat_inputs params;
             params.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});
             GGML_ASSERT(templates.template_default);
             try {
-                common_chat_init(*templates.template_default, params);
+                common_chat_params_init(*templates.template_default, params);
                 if (templates.template_tool_use) {
-                    common_chat_init(*templates.template_tool_use, params);
+                    common_chat_params_init(*templates.template_tool_use, params);
                 }
                 return true;
             } catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
             std::vector<server_task> tasks;

             try {
-                common_chat_data chat_data;
+                common_chat_params chat_data;
                 bool add_special = false;
                 if (tmpl && ctx_server.params_base.use_jinja) {
-                    chat_data = common_chat_init(*tmpl, {
+                    chat_data = common_chat_params_init(*tmpl, {
                         /* .messages = */ json_value(data, "messages", json::array()),
                         /* .tools = */ json_value(data, "tools", json()),
                         /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
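The aggregate initializer above fills common_chat_inputs straight from the HTTP request body; json_value is the server's helper that reads a key with a typed default. A rough standalone equivalent using plain nlohmann::json (the request body is hypothetical):

    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        // Hypothetical /v1/chat/completions request body:
        json data = json::parse(R"({
            "messages": [{"role": "user", "content": "hi"}],
            "tool_choice": "auto"
        })");
        auto messages    = data.value("messages", json::array());
        auto tools       = data.value("tools", json());  // absent key -> null
        auto tool_choice = data.value("tool_choice", std::string("auto"));
        (void) messages; (void) tools; (void) tool_choice;
    }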
#include "minja.hpp" -#include "chat-handler.hpp" +#include "chat.hpp" #include "chat-template.hpp" #include diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 96c38789e..40f83ff0d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -93,7 +93,7 @@ if (NOT WIN32) llama_target_and_test(test-grammar-parser.cpp) llama_target_and_test(test-grammar-integration.cpp) llama_target_and_test(test-llama-grammar.cpp) - llama_target_and_test(test-chat-handler.cpp) + llama_target_and_test(test-chat.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tests/test-chat-handler.cpp b/tests/test-chat.cpp similarity index 97% rename from tests/test-chat-handler.cpp rename to tests/test-chat.cpp index ecdd02bc8..74f6bd1be 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat.cpp @@ -1,4 +1,4 @@ -#include "chat-handler.hpp" +#include "chat.hpp" #include "chat-template.hpp" #include "llama-grammar.h" #include "unicode.h" @@ -169,15 +169,15 @@ struct delta_data { }; static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { - common_chat_params params; + common_chat_inputs params; params.parallel_tool_calls = true; params.messages = json::array(); params.messages.push_back(user_message); params.tools = tools; - auto prefix_data = common_chat_init(tmpl, params); + auto prefix_data = common_chat_params_init(tmpl, params); params.messages.push_back(delta_message); params.add_generation_prompt = false; - auto full_data = common_chat_init(tmpl, params); + auto full_data = common_chat_params_init(tmpl, params); std::string prefix = prefix_data.prompt; std::string full = full_data.prompt; @@ -220,7 +220,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector", ""), tools_params)); // Generic tool calls doesn't generate / parse content-only messages symmetrically. - assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( + assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser( "{\n" " \"response\": \"Hello, world!\"\n" "}"));