diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 39353ba30..d9cbf7923 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }
 
+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role",    msg.role    },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages              = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
+
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply template, fallback to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
+
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9cdf2058f..167e6a971 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1895,30 +1895,44 @@ struct server_context {
         return true;
     }
 
+    bool validate_jinja_templates() const {
+        auto templates = common_chat_templates_from_model(model, "");
+        common_chat_inputs inputs;
+        inputs.messages = json::array({
+            {
+                { "role",    "user" },
+                { "content", "test" },
+            }
+        });
+        GGML_ASSERT(templates.template_default);
+        try {
+            common_chat_params_init(*templates.template_default, inputs);
+            if (templates.template_tool_use) {
+                common_chat_params_init(*templates.template_tool_use, inputs);
+            }
+
+            return true;
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to apply template: %s\n", e.what());
+
+            return false;
+        }
+    }
+
     bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
+        llama_chat_message chat[] = {
+            { "user", "test" }
+        };
 
         if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
+            return validate_jinja_templates();
         } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+            const char *  tmpl     = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            // If llama_chat_apply_template fails to apply template, fallback to using jinja
+            if (chat_res < 0) {
+                return validate_jinja_templates();
+            }
+
             return chat_res > 0;
         }
     }