From 64263910d8497fe07a67998b411c9c56595a8e5b Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 29 Jan 2025 01:15:44 +0000 Subject: [PATCH] Fix firefunction w/ jinja: requires two variables, use the chat handlers everywhere templates are used --- common/chat-handler.cpp | 92 +++++++++++++++++++++----------------- common/common.cpp | 15 +++++-- examples/server/server.cpp | 16 +++---- 3 files changed, 71 insertions(+), 52 deletions(-) diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index ff905ee0b..bb13a6700 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -162,6 +162,14 @@ static void foreach_function(const json & tools, const std::function common_chat_msg { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .tool_calls = */ {}, - }; - }; + data.parser = no_op_text_parser; data.grammar_lazy = false; if (!params.json_schema.is_null()) { if (!params.grammar.empty()) { @@ -777,6 +788,10 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when return common_chat_init_functionary_v3_2_tool_call(tmpl, params); } + if (src.find(" functools[") != std::string::npos) { + // Firefunction v2 requires datetime and functions in the context + return common_chat_init_firefunction_v2_tool_call(tmpl, params); + } if (has_tools) { return common_chat_init_without_tools(tmpl, params); @@ -807,8 +822,5 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc if (src.find("[TOOL_CALLS]") != std::string::npos) { return common_chat_init_mistral_nemo_tool_call(tmpl, params); } - if (src.find(" functools[") != std::string::npos) { - return common_chat_init_firefunction_v2_tool_call(tmpl, params); - } return common_chat_init_generic_tool_call(tmpl, params); } diff --git a/common/common.cpp b/common/common.cpp index fa04d8a69..032754b8a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,6 +12,7 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat-handler.hpp" #include "chat-template.hpp" #include @@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { - auto chat_template = minja::chat_template(tmpl, "", ""); - chat_template.apply({{ + auto chat_template = common_chat_template(tmpl, "", ""); + common_chat_params params; + params.messages = json::array({{ {"role", "user"}, {"content", "test"}, - }}, json(), true); + }}); + common_chat_init(chat_template, params); return true; } catch (const std::exception & e) { LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); @@ -1800,7 +1803,11 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } - return tmpl.apply(messages, /* tools= */ json(), add_ass); + common_chat_params params; + params.messages = messages; + params.add_generation_prompt = add_ass; + auto data = common_chat_init(tmpl, params); + return data.prompt; } int alloc_size = 0; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b0db83a4c..03e95a78b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1794,17 +1794,16 @@ struct server_context { if (use_jinja) { auto templates = common_chat_templates_from_model(model, ""); + common_chat_params params; + params.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); GGML_ASSERT(templates.template_default); try { - templates.template_default->apply({{ - {"role", "user"}, - {"content", "test"}, - }}, json(), true); + common_chat_init(*templates.template_default, params); if (templates.template_tool_use) { - templates.template_tool_use->apply({{ - {"role", "user"}, - {"content", "test"}, - }}, json(), true); + common_chat_init(*templates.template_tool_use, params); } return true; } catch (const std::exception & e) { @@ -3770,6 +3769,7 @@ int main(int argc, char ** argv) { /* .stream = */ json_value(data, "stream", false), /* .grammar = */ json_value(data, "grammar", std::string("")), }); + LOG_INF("Chat format: %s\n", chat_data.format.c_str()); if (data.contains("grammar")) { if (!chat_data.grammar.empty()) { throw std::runtime_error("Cannot provide grammar and tools");