From 3c7784c51cd0c075f8d6586e2a09850fd0842792 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sat, 18 Jan 2025 00:13:16 +0000
Subject: [PATCH] Refactor common_chat_* functions to accept minja template +
 use_jinja option

---
 common/arg.cpp               |  2 +-
 common/common.cpp            | 41 ++++++++++++++++++------------------
 common/common.h              | 27 +++++++++++++-----------
 common/tool-call.cpp         |  4 ++--
 common/tool-call.h           |  4 ++--
 examples/main/main.cpp       | 24 ++++++++++-----------
 examples/run/run.cpp         |  4 ++--
 examples/server/server.cpp   |  2 +-
 examples/server/utils.hpp    |  8 +++----
 tests/test-chat-template.cpp | 14 ++++++------
 tests/test-tool-call.cpp     |  6 +++---
 11 files changed, 71 insertions(+), 65 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index c379e78ef..cb43b0d52 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1919,7 +1919,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.use_jinja = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
diff --git a/common/common.cpp b/common/common.cpp
index 1538cfcab..17659f7a7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1787,10 +1787,19 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }

-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_ass);
+    }
+
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
@@ -1799,7 +1808,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }

-    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.source().c_str();
     std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
@@ -1830,13 +1839,14 @@ std::string common_chat_apply_template(const struct llama_model * model,
     return formatted_chat;
 }

-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1844,29 +1854,20 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }

-std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    const auto add_generation_prompt = true;
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
-    } else {
-        return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
-    }
+    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
 }

 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
diff --git a/common/common.h b/common/common.h
index 5bb8946bd..7cd7389a7 100644
--- a/common/common.h
+++ b/common/common.h
@@ -607,34 +607,37 @@ struct common_chat_msg {

 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

+typedef minja::chat_template llama_chat_template;
+
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);

 // Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);

 // Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-        const minja::chat_template & tmpl, bool use_jinja);
-
+std::string common_chat_format_example(
+        const llama_chat_template & tmpl, bool use_jinja);
 struct llama_chat_templates {
-    minja::chat_template default_template;
-    std::optional<minja::chat_template> tool_use_template;
+    llama_chat_template default_template;
+    std::optional<llama_chat_template> tool_use_template;
 };

 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);

-minja::chat_template llama_chat_template_from_model(
+llama_chat_template llama_chat_template_from_model(
     const struct llama_model * model,
     const std::string & chat_template_override = "",
     bool prefer_tool_use = false);
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index bc0de8ab2..26bb60479 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -74,7 +74,7 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
     }
 }

-llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
+llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template) {
     const auto & src = chat_template.source();

     if (src.find("") != std::string::npos) {
@@ -399,7 +399,7 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages

 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const minja::chat_template & tmpl,
+    const llama_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
diff --git a/common/tool-call.h b/common/tool-call.h
index 2a9c3cf9e..f96ed2b1f 100644
--- a/common/tool-call.h
+++ b/common/tool-call.h
@@ -41,13 +41,13 @@ struct llama_tool_call_handler {

 std::string llama_tool_call_style_name(llama_tool_call_style style);

-llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template);
+llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template);

 llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

 llama_tool_call_handler llama_tool_call_handler_init(
     llama_tool_call_style style,
-    const minja::chat_template & tmpl,
+    const llama_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 264e77629..f72325d77 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -84,14 +84,6 @@ static void sigint_handler(int signo) {
 }
 #endif

-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
@@ -226,7 +218,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -270,10 +262,18 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd_inp;

+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg{role, content};
+        auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back({role, content});
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
     {
         auto prompt = (params.conversation_mode && params.enable_chat_template)
             // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
             // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -766,7 +766,7 @@ int main(int argc, char ** argv) {
                 }

                 if (params.enable_chat_template) {
-                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    chat_add_and_format("assistant", assistant_ss.str());
                 }
                 is_interacting = true;
                 LOG("\n");
@@ -831,7 +831,7 @@ int main(int argc, char ** argv) {
                 bool format_chat = params.conversation_mode && params.enable_chat_template;
                 std::string user_inp = format_chat
-                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                    ? chat_add_and_format("user", std::move(buffer))
                     : std::move(buffer);

                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index b4cbed9be..64cc2d20d 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }

 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
         json messages = json::array();
         for (const auto & msg : llama_data.messages) {
@@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }

 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
     const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a483b9a26..b1028eae5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4389,7 +4389,7 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         get_chat_templates().default_template.source().c_str(),
-        common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());

     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8f9a7517c..ccb384506 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -352,7 +352,7 @@ static llama_tokens format_infill(
 }

 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const struct llama_model * model, const llama_chat_template & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;

     for (size_t i = 0; i < messages.size(); ++i) {
@@ -381,7 +381,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }

-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());

     return formatted_chat;
@@ -582,7 +582,7 @@ static json oaicompat_completion_params_parse(const json & body) {
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const minja::chat_template & tmpl,
+    const llama_chat_template & tmpl,
     llama_tool_call_style tool_call_style,
     bool use_jinja)
 {
@@ -673,7 +673,7 @@ static json oaicompat_completion_params_parse(
             llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
         }
     } else {
-        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+        llama_params["prompt"] = format_chat(model, tmpl, body.at("messages"));
     }

     // Handle "n" field
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 9560d4fa3..3bd11a1f0 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -319,9 +319,10 @@ int main(void) {
     std::vector<common_chat_msg> chat2;
     common_chat_msg sys_msg{"system", "You are a helpful assistant"};

-    auto fmt_sys = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
-        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_sys = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -345,9 +346,10 @@ int main(void) {
     chat2.push_back({"assistant", "I am assistant"});
     common_chat_msg new_msg{"user", "How are you"};

-    auto fmt_single = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_single = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index 329393877..2230bfa65 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -311,7 +311,7 @@ static void test_parsing() {
 }
 static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
-    const minja::chat_template tmpl(read_file(template_file), "", "");
+    const llama_chat_template tmpl(read_file(template_file), "", "");
     auto tool_call_style = llama_tool_call_style_detect(tmpl);
     std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
     assert_equals(expected, tool_call_style);
 }
@@ -331,7 +331,7 @@ static void test_tool_call_style_detection() {
     test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
 }

-static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
+static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
     auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());

@@ -356,7 +356,7 @@ static std::string get_message_prompt_delta(const minja::chat_template & tmpl, c
 static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
     std::cout << "# Testing template: " << template_file << std::endl << std::flush;

-    const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
+    const llama_chat_template tmpl(read_file(template_file), bos_token, eos_token);
     auto tool_call_style = llama_tool_call_style_detect(tmpl);

     auto & tool_calls = tool_calling_message.at("tool_calls");
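
Not part of the patch, just an illustrative sketch of the call pattern this refactor targets: the caller resolves the template(s) once via llama_chat_templates_from_model and threads use_jinja through, instead of passing the model plus a raw template string into every common_chat_* call. The helper name and the surrounding model/history variables below are hypothetical; only the declarations added in common.h above are assumed.

    #include "common.h"

    // Sketch: format one user turn with the model's default chat template.
    // `model` is assumed to come from the usual llama model-loading path;
    // `use_jinja` mirrors the --jinja flag wired up in arg.cpp.
    static std::string format_user_turn(const struct llama_model * model,
                                        std::vector<common_chat_msg> & history,
                                        const std::string & text,
                                        bool use_jinja) {
        // resolve the template(s) once; an empty override means "use model metadata"
        llama_chat_templates templates = llama_chat_templates_from_model(model, /* chat_template_override= */ "");

        common_chat_msg user_msg{"user", text};
        // common_chat_format_single returns only the suffix to append for this new message
        auto formatted = common_chat_format_single(
            templates.default_template, history, user_msg, /* add_ass= */ true, use_jinja);
        history.push_back(user_msg);
        return formatted;
    }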