From c395d4804fd72c8d5d2b65dfa6437e23d6d4eac9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 31 Oct 2024 13:45:10 +0000 Subject: [PATCH] `tool-call`: behaviour-based detection of template features --- common/chat-template.hpp | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/common/chat-template.hpp b/common/chat-template.hpp index 7e3932174..4dd381cef 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -32,22 +32,45 @@ class chat_template { std::string _eos_token; std::shared_ptr _template_root; + bool renders_needles( + const std::vector & needles, + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const + { + try { + auto prompt = apply(messages, tools, add_generation_prompt, extra_context); + for (const auto & needle : needles) { + if (prompt.find(needle) == std::string::npos) { + return false; + } + } + return true; + } catch (const std::exception & e) { + return false; + } + } + public: chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) : _source(source), _bos_token(bos_token), _eos_token(eos_token) { - _supports_tools = source.find("tools") != std::string::npos; - _requires_object_arguments = - source.find("tool_call.arguments | items") != std::string::npos - || source.find("tool_call.arguments | tojson") != std::string::npos; - _supports_system_role = source.find("System role not supported") == std::string::npos; - _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos; - _template_root = minja::Parser::parse(_source, { /* .trim_blocks = */ true, /* .lstrip_blocks = */ true, /* .keep_trailing_newline = */ false, }); + _supports_tools = source.find("tools") != std::string::npos; + _requires_object_arguments = + source.find("tool_call.arguments | items") != std::string::npos + || source.find("tool_call.arguments | tojson") != std::string::npos; + _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos; + + _supports_system_role = renders_needles({""}, { + {{"role", "system"}, {"content", ""}}, + {{"role", "user"}, {"content", "Hey"}} + }, {}, false); } const std::string & source() const { return _source; }