diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp
index 802132b47..f78e169ac 100644
--- a/common/chat-handler.cpp
+++ b/common/chat-handler.cpp
@@ -152,50 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
     return result;
 }
 
-class text_chat_parser : public common_chat_parser {
-public:
-    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
-        return parse_final(input);
-    }
-
-    common_chat_msg parse_final(const std::string & input) override {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
-    }
-
-    std::unique_ptr<common_chat_parser> clone() const override {
-        return std::make_unique<text_chat_parser>();
-    }
-};
-
-class monolithic_chat_parser : public common_chat_parser {
-
-    std::string input_buffer_;
-    std::function<common_chat_msg(const std::string & input)> parse_final_;
-
-public:
-    monolithic_chat_parser(const std::function<common_chat_msg(const std::string & input)> & parse_final) : parse_final_(parse_final) {}
-
-    std::optional<common_chat_msg> parse_partial(const std::string & input) override {
-        input_buffer_ += input;
-        return std::nullopt;
-    }
-
-    common_chat_msg parse_final(const std::string & input) override {
-        input_buffer_ += input;
-        auto out = parse_final_(input_buffer_);
-        input_buffer_.clear();
-        return out;
-    }
-
-    std::unique_ptr<common_chat_parser> clone() const override {
-        return std::make_unique<monolithic_chat_parser>(parse_final_);
-    }
-};
-
 static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
     for (const auto & tool : tools) {
         if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
@@ -289,7 +245,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
 
     data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "generic tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
+    data.parser = [&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
         result.role = "assistant";
@@ -312,7 +268,7 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
             result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
         }
         return result;
-    });
+    };
     return data;
 }
 
@@ -355,9 +311,9 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "mistral nemo tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
+    data.parser = [](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
-    });
+    };
     return data;
 }
 
@@ -441,7 +397,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
     data.format = "llama 3.1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = [params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");
         static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
@@ -472,7 +428,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
             };
         }
         return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-    });
+    };
     return data;
 }
 
@@ -505,12 +461,12 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
     data.additional_stops.push_back("<|eom_id|>");
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
     data.format = "llama 3.2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
         auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
         return res;
-    });
+    };
     return data;
 }
 
@@ -532,12 +488,12 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
     }, grammar_options);
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "deepseek r1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex trigger_regex("<|tool▁calls▁begin|>");
         static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
         static std::regex close_regex("```<|tool▁call▁end|>");
         return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
-    });
+    };
     return data;
 }
 
@@ -573,9 +529,9 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
     data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "firefunction v2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
+    data.parser = [](const std::string & input) {
         return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
-    });
+    };
     return data;
 }
 
@@ -610,7 +566,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "functionary v3.2 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) {
+    data.parser = [params](const std::string & input) {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
 
@@ -619,7 +575,7 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
             res.content = res.content.substr(4);
         }
         return res;
-    });
+    };
     return data;
 }
 
@@ -674,7 +630,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "functionary v3.1 llama 3.1 tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
+    data.parser = [params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -695,7 +651,7 @@ static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(co
         static std::regex function_regex(R"(<function=(\w+)>)");
         static std::regex close_regex(R"(</function>)");
         return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
-    });
+    };
     return data;
 }
 
@@ -726,7 +682,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "hermes 2 pro tool calls";
-    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
+    data.parser = [&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
             std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -779,7 +735,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
                 /* .tool_calls = */ {},
             };
         }
-    });
+    };
     return data;
 }
 
@@ -787,7 +743,13 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
     common_chat_data data;
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "content-only";
-    data.parser = std::make_unique<text_chat_parser>();
+    data.parser = [](const std::string & input) -> common_chat_msg {
+        return {
+            /* .role = */ "assistant",
+            /* .content = */ input,
+            /* .tool_calls = */ {},
+        };
+    };
     data.grammar_lazy = false;
     if (!params.json_schema.is_null()) {
         if (!params.grammar.empty()) {
diff --git a/common/chat-handler.hpp b/common/chat-handler.hpp
index e640112a2..24b96706c 100644
--- a/common/chat-handler.hpp
+++ b/common/chat-handler.hpp
@@ -27,21 +27,14 @@ struct common_chat_params {
     bool add_generation_prompt = true;
 };
 
-class common_chat_parser {
-public:
-    virtual ~common_chat_parser() = default;
-
-    virtual std::optional<common_chat_msg> parse_partial(const std::string & input) = 0;
-    virtual common_chat_msg parse_final(const std::string & input) = 0;
-    virtual std::unique_ptr<common_chat_parser> clone() const = 0;
-};
+typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
 
 struct common_chat_data {
     json prompt;
     std::string grammar;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> additional_stops;
-    std::unique_ptr<common_chat_parser> parser;
+    common_chat_parser parser;
     std::string format; // For debugging and testing.
     bool grammar_lazy = false;
 };
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8f4aca6a8..b0db83a4c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -117,7 +117,7 @@ struct slot_params {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    std::shared_ptr<common_chat_parser> chat_parser;
+    common_chat_parser chat_parser;
 
     json to_json() const {
         std::vector<std::string> samplers;
@@ -768,7 +768,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_chat_msg;
     std::shared_ptr<common_chat_parser> chat_parser;
 
     virtual int get_index() override {
@@ -2220,16 +2219,6 @@ struct server_context {
     }
 
     void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
-        common_chat_msg msg;
-        if (slot.params.chat_parser) {
-            if (auto opt_msg = slot.params.chat_parser->parse_partial(tkn.text_to_send)) {
-                msg = *opt_msg;
-            } else {
-                return;
-            }
-        } else {
-            msg.content = tkn.text_to_send;
-        }
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
         res->id = slot.id_task;
@@ -2245,7 +2234,6 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        res->oaicompat_chat_msg = msg;
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2286,18 +2274,14 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-        if (!slot.params.chat_parser) {
+        if (slot.params.chat_parser) {
+            res->oaicompat_chat_msg = slot.params.chat_parser(slot.generated_text);
+        } else {
             res->oaicompat_chat_msg = {
                 /* .role = */ "assistant",
                 /* .content = */ slot.generated_text,
                 /* .tool_calls = */ {}
             };
-        } else if (slot.stop == STOP_TYPE_LIMIT) {
-            if (auto opt_msg = slot.params.chat_parser->parse_partial(slot.generated_text)) {
-                res->oaicompat_chat_msg = *opt_msg;
-            }
-        } else {
-            res->oaicompat_chat_msg = slot.params.chat_parser->parse_final(slot.generated_text);
         }
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -3835,9 +3819,7 @@ int main(int argc, char ** argv) {
             task.params.sampling.grammar_trigger_words.push_back(trigger);
         }
         task.params.antiprompt = chat_data.additional_stops;
-        if (chat_data.parser) {
-            task.params.chat_parser = i == tokenized_prompts.size() ? std::move(chat_data.parser) : std::move(chat_data.parser->clone());
-        }
+        task.params.chat_parser = chat_data.parser;
         if (task.params.sampling.grammar_lazy) {
             GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
         }
diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp
index 079dcac82..309f74d56 100644
--- a/tests/test-chat-handler.cpp
+++ b/tests/test-chat-handler.cpp
@@ -397,7 +397,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<
 
-    const auto msg = chat_data.parser->parse_final(full_delta);
+    const auto msg = chat_data.parser(full_delta);
     assert_msg_equals(expected_msg, msg);
 
     auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
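
Reviewer note (not part of the patch): the refactor replaces the polymorphic
common_chat_parser class hierarchy with a plain std::function typedef, so a
parser is now a copyable callback that is invoked directly instead of going
through parse_partial()/parse_final() and clone(). A minimal sketch of the new
call pattern, using only the typedef and the common_chat_msg fields visible in
the diff (the struct here is simplified; everything else is assumed):

    #include <functional>
    #include <string>

    // Simplified stand-in for common_chat_msg, reduced to the fields
    // referenced in the patch (tool_calls omitted for brevity).
    struct common_chat_msg {
        std::string role;
        std::string content;
    };

    // Mirrors the typedef added in common/chat-handler.hpp.
    typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;

    int main() {
        // A "content-only" parser, as in common_chat_init_without_tools().
        common_chat_parser parser = [](const std::string & input) -> common_chat_msg {
            return { /* .role = */ "assistant", /* .content = */ input };
        };

        common_chat_msg msg = parser("hello");  // direct call replaces parse_final()
        common_chat_parser copy = parser;       // copyable: no clone() needed per slot
        (void) msg; (void) copy;
        return 0;
    }

This is also why server.cpp can drop the std::move/clone() branch when fanning
tasks out over multiple prompts: each task can simply take its own copy of the
functor.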