diff --git a/common/common.cpp b/common/common.cpp
index de2adba1f..8d6e8b0cb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1831,12 +1831,15 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
 }
 
 std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
     bool use_jinja,
     const common_params_tools & tools) {
+    const auto & tmpl_selected =
+        tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
+
     if (use_jinja) {
         common_chat_inputs inputs;
@@ -1844,9 +1847,11 @@ std::string common_chat_apply_template(
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
+
         if (tools.tools() != nullptr) {
             inputs.tools = *tools.tools();
         }
+
         auto choice = tools.choice();
         if (std::holds_alternative(choice)) {
             inputs.tool_choice = std::get(choice);
@@ -1857,9 +1862,10 @@ std::string common_chat_apply_template(
                 inputs.tool_choice = *choice_ptr;
             }
         }
+
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
+        return common_chat_params_init(tmpl_selected, inputs).prompt;
     }
 
     int alloc_size = 0;
@@ -1872,7 +1878,7 @@ std::string common_chat_apply_template(
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1884,7 +1890,7 @@ std::string common_chat_apply_template(
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1892,7 +1898,7 @@ std::string common_chat_apply_template(
 }
 
 std::string common_chat_format_single(
-        const common_chat_template & tmpl,
+        const common_chat_templates & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
         bool add_ass,
@@ -1916,7 +1922,7 @@ std::string common_chat_format_single(
     return ss.str();
 }
 
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant", {}},
         {"user", "Hello", {}},
diff --git a/common/common.h b/common/common.h
index 417eb546c..c3f4f66ee 100644
--- a/common/common.h
+++ b/common/common.h
@@ -675,7 +675,7 @@ struct common_chat_templates {
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & chat,
     bool add_ass,
     bool use_jinja,
@@ -683,7 +683,7 @@ std::string common_chat_apply_template(
 
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
    const std::vector<common_chat_msg> & past_msg,
    const common_chat_msg & new_msg,
    bool add_ass,
@@ -692,7 +692,7 @@ std::string common_chat_format_single(
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
+    const common_chat_templates & tmpl, bool use_jinja);
 
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d19a24331..d562522a4 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -268,9 +268,9 @@ int main(int argc, char ** argv) {
             const common_params_tools & tools = common_params_tools()) {
         common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(
-            *chat_templates.template_default, chat_msgs, new_msg, role == "user",
-            g_params->use_jinja, tools);
+
+        auto formatted = common_chat_format_single(chat_templates, chat_msgs,
+            new_msg, role == "user", g_params->use_jinja, tools);
         chat_msgs.push_back({role, content, {}});
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e0acc4705..3ceba0558 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4468,7 +4468,7 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 5f97df5fd..f1b4ee5b5 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -348,7 +348,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
             llama_params["stop"].push_back(stop);
         }
     } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+        llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
     }
 
     // Handle "n" field
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index e0314ae1d..022205b7b 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -339,7 +339,8 @@ int main(void) {
     common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
 
     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
@@ -366,7 +367,8 @@ int main(void) {
     common_chat_msg new_msg{"user", "How are you", {}};
 
     auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
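For context, a minimal sketch of the call pattern after this change, mirroring the updated tests in tests/test-chat-template.cpp: callers now hand over the whole common_chat_templates bundle, and the tool-use template (when present and tools are supplied) is selected inside common_chat_apply_template rather than at each call site. The include paths and the inline template string below are assumptions for illustration, not part of this patch.

    #include "chat-template.hpp" // common_chat_template (assumed path)
    #include "common.h"          // common_chat_templates, common_chat_msg, common_chat_format_single

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        // Normally the bundle comes from common_chat_templates_from_model();
        // here it is built by hand, the same way the updated tests do.
        common_chat_templates tmpls;
        tmpls.template_default.reset(new common_chat_template(
            "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}",
            /* bos_token= */ "", /* eos_token= */ ""));

        std::vector<common_chat_msg> history = {
            {"system", "You are a helpful assistant", {}},
        };
        common_chat_msg new_msg{"user", "Hello", {}};

        // The bundle itself is passed; no more *tmpls.template_default at the call site.
        std::string formatted = common_chat_format_single(
            tmpls, history, new_msg, /* add_ass= */ true, /* use_jinja= */ false);

        printf("%s\n", formatted.c_str());
        return 0;
    }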