diff --git a/common/common.cpp b/common/common.cpp
index 8009601de..b390f1df3 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1811,15 +1811,23 @@ std::string common_chat_format_single(const struct llama_model * model,
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl) {
+std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    const auto add_generation_prompt = true;
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
+    } else {
+        return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
+    }
 }
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
diff --git a/common/common.h b/common/common.h
index dea779d09..24a91cfa9 100644
--- a/common/common.h
+++ b/common/common.h
@@ -619,7 +619,7 @@ std::string common_chat_format_single(const struct llama_model * model,
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl);
+        const minja::chat_template & tmpl, bool use_jinja);
 
 struct llama_chat_templates {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 39666a0e8..11038a7c6 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -165,6 +165,7 @@ int main(int argc, char ** argv) {
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
+    auto chat_templates = llama_chat_templates_from_model(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -207,7 +208,7 @@ int main(int argc, char ** argv) {
     }
 
     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
+    const bool has_chat_template = !chat_templates.default_template.source().empty();
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -225,7 +226,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 15bcb7e0e..dc302ddc1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4287,8 +4287,8 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
-        common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+        get_chat_templates().default_template.source().c_str(),
+        common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
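
Taken together, the changes above mean callers no longer pass a raw template string: they first materialize the templates from the model via `llama_chat_templates_from_model`, then hand the parsed `default_template` plus the `use_jinja` flag to `common_chat_format_example`. Below is a minimal sketch of that caller-side pattern, mirroring the `examples/main/main.cpp` hunk; it assumes a `common_params` struct carrying the `chat_template` and `use_jinja` fields seen in the diff, and the `LOG_INF` macro from `log.h` (names not shown in the diff are illustrative only).

```cpp
// Sketch only: mirrors the caller-side pattern from examples/main/main.cpp above.
#include "common.h"
#include "log.h"

static void print_chat_example(const llama_model * model, const common_params & params) {
    // Resolve the model's built-in template, honoring any --chat-template override.
    auto chat_templates = llama_chat_templates_from_model(model, params.chat_template);

    // The formatter now takes the parsed template and the Jinja toggle
    // instead of a raw template string.
    const std::string example =
        common_chat_format_example(model, chat_templates.default_template, params.use_jinja);

    LOG_INF("%s: chat template example:\n%s\n", __func__, example.c_str());
}
```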