Refactor common_chat_* functions to accept minja template + use_jinja option

This commit is contained in:
ochafik 2025-01-18 00:13:16 +00:00
parent d47f40caea
commit 3c7784c51c
11 changed files with 71 additions and 65 deletions

View file

@ -1919,7 +1919,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) { [](common_params & params) {
params.use_jinja = true; params.use_jinja = true;
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg( add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE", {"--chat-template"}, "JINJA_TEMPLATE",
string_format( string_format(

View file

@ -1787,10 +1787,19 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
return res >= 0; return res >= 0;
} }
std::string common_chat_apply_template(const struct llama_model * model, std::string common_chat_apply_template(
const std::string & tmpl, const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs, const std::vector<common_chat_msg> & msgs,
bool add_ass) { bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
}
int alloc_size = 0; int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat; std::vector<llama_chat_message> chat;
@ -1799,7 +1808,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
alloc_size += (msg.role.size() + msg.content.size()) * 1.25; alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
} }
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str(); const char * ptr_tmpl = tmpl.source().c_str();
std::vector<char> buf(alloc_size); std::vector<char> buf(alloc_size);
// run the first time to get the total output length // run the first time to get the total output length
@ -1830,13 +1839,14 @@ std::string common_chat_apply_template(const struct llama_model * model,
return formatted_chat; return formatted_chat;
} }
std::string common_chat_format_single(const struct llama_model * model, std::string common_chat_format_single(
const std::string & tmpl, const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg, const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg, const common_chat_msg & new_msg,
bool add_ass) { bool add_ass,
bool use_jinja) {
std::ostringstream ss; std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg); std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version // if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@ -1844,29 +1854,20 @@ std::string common_chat_format_single(const struct llama_model * model,
}; };
// format chat with new_msg // format chat with new_msg
chat_new.push_back(new_msg); chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part // get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str(); return ss.str();
} }
std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) { std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = { std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"}, {"system", "You are a helpful assistant"},
{"user", "Hello"}, {"user", "Hello"},
{"assistant", "Hi there"}, {"assistant", "Hi there"},
{"user", "How are you?"}, {"user", "How are you?"},
}; };
const auto add_generation_prompt = true; return common_chat_apply_template(tmpl, msgs, true, use_jinja);
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
} else {
return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
}
} }
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)

View file

@ -607,34 +607,37 @@ struct common_chat_msg {
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
typedef minja::chat_template llama_chat_template;
// CPP wrapper for llama_chat_apply_template // CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml // If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error // If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(const struct llama_model * model, std::string common_chat_apply_template(
const std::string & tmpl, const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & chat, const std::vector<common_chat_msg> & chat,
bool add_ass); bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history // Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(const struct llama_model * model, std::string common_chat_format_single(
const std::string & tmpl, const llama_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg, const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg, const common_chat_msg & new_msg,
bool add_ass); bool add_ass,
bool use_jinja);
// Returns an example of formatted chat // Returns an example of formatted chat
std::string common_chat_format_example(const struct llama_model * model, std::string common_chat_format_example(
const minja::chat_template & tmpl, bool use_jinja); const llama_chat_template & tmpl, bool use_jinja);
struct llama_chat_templates { struct llama_chat_templates {
minja::chat_template default_template; llama_chat_template default_template;
std::optional<minja::chat_template> tool_use_template; std::optional<llama_chat_template> tool_use_template;
}; };
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
minja::chat_template llama_chat_template_from_model( llama_chat_template llama_chat_template_from_model(
const struct llama_model * model, const struct llama_model * model,
const std::string & chat_template_override = "", const std::string & chat_template_override = "",
bool prefer_tool_use = false); bool prefer_tool_use = false);

View file

@ -74,7 +74,7 @@ std::string llama_tool_call_style_name(llama_tool_call_style style) {
} }
} }
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) { llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template) {
const auto & src = chat_template.source(); const auto & src = chat_template.source();
if (src.find("<tool_call>") != std::string::npos) { if (src.find("<tool_call>") != std::string::npos) {
@ -399,7 +399,7 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
llama_tool_call_style style, llama_tool_call_style style,
const minja::chat_template & tmpl, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
const nlohmann::ordered_json & parallel_tool_calls, const nlohmann::ordered_json & parallel_tool_calls,
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,

View file

@ -41,13 +41,13 @@ struct llama_tool_call_handler {
std::string llama_tool_call_style_name(llama_tool_call_style style); std::string llama_tool_call_style_name(llama_tool_call_style style);
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template); llama_tool_call_style llama_tool_call_style_detect(const llama_chat_template & chat_template);
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
llama_tool_call_style style, llama_tool_call_style style,
const minja::chat_template & tmpl, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
const nlohmann::ordered_json & parallel_tool_calls, const nlohmann::ordered_json & parallel_tool_calls,
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,

View file

@ -84,14 +84,6 @@ static void sigint_handler(int signo) {
} }
#endif #endif
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
common_params params; common_params params;
g_params = &params; g_params = &params;
@ -226,7 +218,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode // print chat template example in conversation mode
if (params.conversation_mode) { if (params.conversation_mode) {
if (params.enable_chat_template) { if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str()); LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str());
} else { } else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
} }
@ -270,10 +262,18 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
{ {
auto prompt = (params.conversation_mode && params.enable_chat_template) auto prompt = (params.conversation_mode && params.enable_chat_template)
// format the system prompt in conversation mode (fallback to default if empty) // format the system prompt in conversation mode (fallback to default if empty)
? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
// otherwise use the prompt as is // otherwise use the prompt as is
: params.prompt; : params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@ -766,7 +766,7 @@ int main(int argc, char ** argv) {
} }
if (params.enable_chat_template) { if (params.enable_chat_template) {
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); chat_add_and_format("assistant", assistant_ss.str());
} }
is_interacting = true; is_interacting = true;
LOG("\n"); LOG("\n");
@ -831,7 +831,7 @@ int main(int argc, char ** argv) {
bool format_chat = params.conversation_mode && params.enable_chat_template; bool format_chat = params.conversation_mode && params.enable_chat_template;
std::string user_inp = format_chat std::string user_inp = format_chat
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) ? chat_add_and_format("user", std::move(buffer))
: std::move(buffer); : std::move(buffer);
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);

View file

@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData &
} }
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) { if (use_jinja) {
json messages = json::array(); json messages = json::array();
for (const auto & msg : llama_data.messages) { for (const auto & msg : llama_data.messages) {
@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
} }
// Helper function to apply the chat template and handle errors // Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
if (new_len < 0) { if (new_len < 0) {
printe("failed to apply the chat template\n"); printe("failed to apply the chat template\n");

View file

@ -4389,7 +4389,7 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used // print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
get_chat_templates().default_template.source().c_str(), get_chat_templates().default_template.source().c_str(),
common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task(std::bind( ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1)); &server_context::process_single_task, &ctx_server, std::placeholders::_1));

View file

@ -352,7 +352,7 @@ static llama_tokens format_infill(
} }
// Format given chat. If tmpl is empty, we take the template from model metadata // Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { inline std::string format_chat(const struct llama_model * model, const llama_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat; std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) { for (size_t i = 0; i < messages.size(); ++i) {
@ -381,7 +381,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
chat.push_back({role, content}); chat.push_back({role, content});
} }
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true); const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat; return formatted_chat;
@ -582,7 +582,7 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const struct llama_model * model, const struct llama_model * model,
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
const minja::chat_template & tmpl, const llama_chat_template & tmpl,
llama_tool_call_style tool_call_style, llama_tool_call_style tool_call_style,
bool use_jinja) bool use_jinja)
{ {
@ -673,7 +673,7 @@ static json oaicompat_completion_params_parse(
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} }
} else { } else {
llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages")); llama_params["prompt"] = format_chat(model, tmpl, body.at("messages"));
} }
// Handle "n" field // Handle "n" field

View file

@ -319,9 +319,10 @@ int main(void) {
std::vector<common_chat_msg> chat2; std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant"}; common_chat_msg sys_msg{"system", "You are a helpful assistant"};
auto fmt_sys = [&](std::string tmpl) { auto fmt_sys = [&](std::string tmpl_str) {
auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); minja::chat_template tmpl(tmpl_str, "", "");
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;
}; };
@ -345,9 +346,10 @@ int main(void) {
chat2.push_back({"assistant", "I am assistant"}); chat2.push_back({"assistant", "I am assistant"});
common_chat_msg new_msg{"user", "How are you"}; common_chat_msg new_msg{"user", "How are you"};
auto fmt_single = [&](std::string tmpl) { auto fmt_single = [&](std::string tmpl_str) {
auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true); minja::chat_template tmpl(tmpl_str, "", "");
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;
}; };

View file

@ -311,7 +311,7 @@ static void test_parsing() {
} }
static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) { static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
const minja::chat_template tmpl(read_file(template_file), "<s>", "</s>"); const llama_chat_template tmpl(read_file(template_file), "<s>", "</s>");
auto tool_call_style = llama_tool_call_style_detect(tmpl); auto tool_call_style = llama_tool_call_style_detect(tmpl);
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tool_call_style); assert_equals(expected, tool_call_style);
@ -331,7 +331,7 @@ static void test_tool_call_style_detection() {
test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic); test_tool_call_style("tests/chat/templates/google-gemma-7b-it.jinja", Generic);
} }
static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) { static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
@ -356,7 +356,7 @@ static std::string get_message_prompt_delta(const minja::chat_template & tmpl, c
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) { static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
std::cout << "# Testing template: " << template_file << std::endl << std::flush; std::cout << "# Testing template: " << template_file << std::endl << std::flush;
const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token); const llama_chat_template tmpl(read_file(template_file), bos_token, eos_token);
auto tool_call_style = llama_tool_call_style_detect(tmpl); auto tool_call_style = llama_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls"); auto & tool_calls = tool_calling_message.at("tool_calls");