From d7b31a9d84297b493a61c9a8ee3a458e7ccc64a7 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 10 Feb 2025 09:34:09 +0000 Subject: [PATCH] sync: minja (https://github.com/google/minja/commit/a72057e5190de2c612d4598bb10b4bfd0f53011f) (#11774) --- common/chat-template.hpp | 28 +++++++++++++++++++++------- common/minja.hpp | 33 ++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/common/chat-template.hpp b/common/chat-template.hpp index 0e88fb361..882ba41bd 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -249,16 +249,30 @@ class chat_template { inputs.add_generation_prompt = false; full = apply(inputs); } - - if (full.find(prefix) != 0) { - if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + auto eos_pos_last = full.rfind(eos_token_); + if (eos_pos_last == prefix.size() - eos_token_.size() || + (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { + full = full.substr(0, eos_pos_last); + } + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. + continue; + } + common_prefix_length = i + 1; } - if (full.find(prefix) != 0) { + auto example = full.substr(common_prefix_length); + if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) { fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } else { + tool_call_example_ = example; } - tool_call_example_ = full.substr(prefix.size()); } } catch (const std::exception & e) { fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); @@ -363,7 +377,7 @@ class chat_template { if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + - (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); } else { adjusted_messages = inputs.messages; } diff --git a/common/minja.hpp b/common/minja.hpp index c304b5c66..c58dd66e0 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) { return s.substr(start, end - start + 1); } +static std::string capitalize(const std::string & s) { + if (s.empty()) return s; + auto result = s; + result[0] = std::toupper(result[0]); + return result; +} + static std::string html_escape(const std::string & s) { std::string result; result.reserve(s.size()); @@ -1462,6 +1469,9 @@ public: if (method->get_name() == "strip") { vargs.expectArgs("strip method", {0, 0}, {0, 0}); return Value(strip(str)); + } else if (method->get_name() == "capitalize") { + vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); + return Value(capitalize(str)); } else if (method->get_name() == "endswith") { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); @@ -1792,7 +1802,7 @@ private: auto left = parseStringConcat(); if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); - static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)"); static std::regex not_tok(R"(not\b)"); std::string op_str; while (!(op_str = consumeToken(compare_tok)).empty()) { @@ -2171,7 +2181,7 @@ private: using TemplateTokenIterator = TemplateTokenVector::const_iterator; std::vector parseVarNames() { - static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)"); std::vector group; if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); @@ -2194,13 +2204,13 @@ private: } TemplateTokenVector tokenize() { - static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})"); + static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); - static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); - static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); - static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); + static std::regex block_close_regex(R"(\s*([-~])?%\})"); TemplateTokenVector tokens; std::vector group; @@ -2284,7 +2294,7 @@ private: auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "set") { - static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); std::string ns; std::vector var_names; @@ -2336,6 +2346,11 @@ private: throw std::runtime_error("Unexpected block: " + keyword); } } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + throw std::runtime_error("Internal error: Expected a comment"); + throw std::runtime_error("Missing end of comment tag"); + } auto text_end = it + match.position(); text = std::string(it, text_end); it = text_end; @@ -2400,7 +2415,7 @@ private: auto text = text_token->text; if (post_space == SpaceHandling::Strip) { - static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + static std::regex trailing_space_regex(R"(\s+$)"); text = std::regex_replace(text, trailing_space_regex, ""); } else if (options.lstrip_blocks && it != end) { auto i = text.size(); @@ -2410,7 +2425,7 @@ private: } } if (pre_space == SpaceHandling::Strip) { - static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { if (text.length() > 0 && text[0] == '\n') {