From d23abdc3f605814a00ccd94534ac659ef43252a7 Mon Sep 17 00:00:00 2001
From: Eric Curtin
Date: Wed, 5 Feb 2025 21:08:23 +0000
Subject: [PATCH 1/3] When llama_chat_apply_template doesn't work

Try minja. With granite-code, if we fall back to jinja on failure, it's fine.

Co-authored-by: Michael Engel
Signed-off-by: Eric Curtin
---
 examples/run/run.cpp       | 61 +++++++++++++++++++++++---------------
 examples/server/server.cpp | 52 ++++++++++++++++++++------------
 2 files changed, 70 insertions(+), 43 deletions(-)

diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 39353ba30..d9cbf7923 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData & llama_data) {
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }

+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role", msg.role },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append,
                                bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
+
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply the template, fall back to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
+
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9cdf2058f..6a180b13d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1895,30 +1895,44 @@ struct server_context {
         return true;
     }

+    bool apply_jinja_templates() const {
+        auto templates = common_chat_templates_from_model(model, "");
+        common_chat_inputs inputs;
+        inputs.messages = json::array({
+                {
+                    { "role", "user" },
+                    { "content", "test" },
+                }
+        });
+        GGML_ASSERT(templates.template_default);
+        try {
+            common_chat_params_init(*templates.template_default, inputs);
+            if (templates.template_tool_use) {
+                common_chat_params_init(*templates.template_tool_use, inputs);
+            }
+
+            return true;
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to apply template: %s\n", e.what());
+
+            return false;
+        }
+    }
+
     bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
+        llama_chat_message chat[] = {
+            { "user", "test" }
+        };

         if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
+            return apply_jinja_templates();
         } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+            const char * tmpl      = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            if (chat_res < 0) {
+                return apply_jinja_templates();
+            }
+
             return chat_res > 0;
         }
     }
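The essence of PATCH 1/3 is a try-then-fall-back flow: ask llama_chat_apply_template to format the conversation first, and only hand the template to the minja-based jinja renderer when the builtin parser reports failure (a negative return), as happens with granite-code's template. The standalone sketch below models that control flow only; apply_builtin_template and apply_jinja_fallback are hypothetical stand-ins, not llama.cpp APIs.

    // Minimal model of the fallback flow in PATCH 1/3 (illustrative only).
    #include <cstdio>
    #include <string>

    // Stands in for llama_chat_apply_template(): a negative return means the
    // template could not be applied (hypothetical failure condition below).
    static int apply_builtin_template(const std::string & tmpl, std::string & out) {
        if (tmpl.find("{%") != std::string::npos) {
            return -1;  // pretend jinja constructs are unsupported here
        }
        out = "<|user|>test<|assistant|>";
        return static_cast<int>(out.size());
    }

    // Stands in for the minja-based renderer used as the fallback.
    static int apply_jinja_fallback(const std::string & tmpl, std::string & out) {
        (void) tmpl;  // a real renderer would evaluate tmpl against the messages
        out = "rendered via jinja";
        return static_cast<int>(out.size());
    }

    int main() {
        const std::string tmpl = "{% for m in messages %}...{% endfor %}";
        std::string       out;
        // Same control flow as the patched apply_chat_template():
        // builtin first, jinja only on failure.
        int res = apply_builtin_template(tmpl, out);
        if (res < 0) {
            res = apply_jinja_fallback(tmpl, out);
        }
        printf("%d: %s\n", res, out.c_str());
        return 0;
    }

The same shape repeats in server.cpp above, where a negative result from the C formatter now routes template validation through the jinja path instead of failing outright.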
From a4e9e4d41bffe8b841d0b33e01ea87f2045c217f Mon Sep 17 00:00:00 2001
From: Eric Curtin
Date: Wed, 5 Feb 2025 22:59:10 +0000
Subject: [PATCH 2/3] Update examples/server/server.cpp

Co-authored-by: Xuan-Son Nguyen
---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6a180b13d..d0fee5742 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1895,7 +1895,7 @@ struct server_context {
         return true;
     }

-    bool apply_jinja_templates() const {
+    bool validate_jinja_templates() const {
         auto templates = common_chat_templates_from_model(model, "");
         common_chat_inputs inputs;
         inputs.messages = json::array({

From e00c9d1c5e760c6b1d0f935fb80318d37b09cb46 Mon Sep 17 00:00:00 2001
From: Eric Curtin
Date: Wed, 5 Feb 2025 22:59:16 +0000
Subject: [PATCH 3/3] Update examples/server/server.cpp

Co-authored-by: Xuan-Son Nguyen
Signed-off-by: Eric Curtin
---
 examples/server/server.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d0fee5742..167e6a971 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1899,10 +1899,10 @@ struct server_context {
         auto templates = common_chat_templates_from_model(model, "");
         common_chat_inputs inputs;
         inputs.messages = json::array({
-                {
-                    { "role", "user" },
-                    { "content", "test" },
-                }
+            {
+                { "role", "user" },
+                { "content", "test" },
+            }
         });
         GGML_ASSERT(templates.template_default);
         try {
@@ -1925,12 +1925,12 @@ struct server_context {
         };

         if (use_jinja) {
-            return apply_jinja_templates();
+            return validate_jinja_templates();
         } else {
             const char * tmpl      = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
             if (chat_res < 0) {
-                return apply_jinja_templates();
+                return validate_jinja_templates();
             }

             return chat_res > 0;
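After the renames in PATCH 2/3 and 3/3, the series leaves the server with a single validation rule at startup: render a dummy { "role": "user", "content": "test" } message, and treat a thrown exception (jinja path) or a negative return (C path, which now retries via jinja) as an unusable template. A compact sketch of that validation shape, using a hypothetical chat_template type in place of the real common_chat_template:

    // Sketch of the startup-validation shape (illustrative only;
    // chat_template and validate_template are hypothetical stand-ins).
    #include <cstdio>
    #include <stdexcept>
    #include <string>

    struct chat_template {
        std::string source;
        std::string apply(const std::string & role, const std::string & content) const {
            if (source.empty()) {
                throw std::runtime_error("empty template");
            }
            return "<|" + role + "|>" + content;  // trivial stand-in rendering
        }
    };

    // Mirrors validate_jinja_templates(): render a dummy user/"test" message
    // and treat any exception as "this model's template is unusable".
    static bool validate_template(const chat_template & tmpl) {
        try {
            tmpl.apply("user", "test");
            return true;
        } catch (const std::exception & e) {
            fprintf(stderr, "failed to apply template: %s\n", e.what());
            return false;
        }
    }

    int main() {
        printf("valid: %d\n", validate_template(chat_template{ "some template" }));
        printf("valid: %d\n", validate_template(chat_template{ "" }));
        return 0;
    }

Probing with a throwaway message at load time surfaces a broken template immediately, rather than on the first user request.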