Add Jinja template support (#11016)

* Copy minja from 58f0ca6dd7

* Add --jinja and --chat-template-file flags

* Add missing <optional> include

* Avoid print in get_hf_chat_template.py

* No designated initializers yet

* Try and work around msvc++ non-macro max resolution quirk

* Update test_chat_completion.py

* Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

* Refactor test-chat-template

* Test templates w/ minja

* Fix deprecation

* Add --jinja to llama-run

* Update common_chat_format_example to use minja template wrapper

* Test chat_template in e2e test

* Update utils.py

* Update test_chat_completion.py

* Update run.cpp

* Update arg.cpp

* Refactor common_chat_* functions to accept minja template + use_jinja option

* Attempt to fix linkage of LLAMA_CHATML_TEMPLATE

* Revert LLAMA_CHATML_TEMPLATE refactor

* Normalize newlines in test-chat-templates for windows tests

* Forward decl minja::chat_template to avoid eager json dep

* Flush stdout in chat template before potential crash

* Fix copy elision warning

* Rm unused optional include

* Add missing optional include to server.cpp

* Disable jinja test that has a cryptic windows failure

* minja: fix vigogne (https://github.com/google/minja/pull/22)

* Apply suggestions from code review

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Finish suggested renamings

* Move chat_templates inside server_context + remove mutex

* Update --chat-template-file w/ recent change to --chat-template

* Refactor chat template validation

* Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr)

* Warn against missing eos / bos tokens when jinja template references them

* rename: common_chat_template[s]

* reinstate assert on chat_templates.template_default

* Update minja to b8437df626

* Update minja to https://github.com/google/minja/pull/25

* Update minja from https://github.com/google/minja/pull/27

* rm unused optional header

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Olivier Chafik 2025-01-21 13:18:51 +00:00 committed by GitHub
parent e28245f35f
commit 6171c9d258
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 3563 additions and 133 deletions

View file

@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.hpp
common.cpp
common.h
console.cpp
@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
json.hpp
log.cpp
log.h
minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp

View file

@ -325,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
params.chat_template.c_str(),
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
));
}
return true;
}
@ -1947,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
if (!common_chat_verify_template(value)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s\n"
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
value.c_str()
));
}
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(
"set custom jinja chat template file (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.chat_template));
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
add_opt(common_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

249
common/chat-template.hpp Normal file
View file

@ -0,0 +1,249 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
namespace minja {
class chat_template {
public:
private:
bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;
std::string try_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
// fprintf(stderr, "Error: %s\n", e.what());
return "";
}
}
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::Parser::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;
auto renders_string_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", {
{"code", "print('Hello, World!')"},
}},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
requires_object_arguments_ = renders_object_arguments;
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
supports_system_role_ = try_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;
}
const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
json actual_messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
actual_messages = json::array();
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto & message_ : messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tools_) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
if (!supports_tools_) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
continue;
}
const auto & function = tool_call.at("function");
auto tc = json {
{"name", function.at("name")},
{"arguments", function.at("arguments")},
};
if (tool_call.contains("id")) {
tc["id"] = tool_call["id"];
}
tool_calls.push_back(tc);
}
auto obj = json {
{"tool_calls", tool_calls},
};
if (!content.is_null() && content != "") {
obj["content"] = content;
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
}
}
if (!supports_tools_ && role == "tool") {
message["role"] = "user";
auto obj = json {
{"tool_response", {
{"tool", message.at("name")},
{"content", message.at("content")},
}},
};
if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
}
message["content"] = obj.dump(2);
message.erase("name");
}
if (!message["content"].is_null() && !supports_system_role_) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
actual_messages.push_back(message);
}
flush_sys();
} else {
actual_messages = messages;
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", bos_token_},
{"eos_token", eos_token_},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return template_root_->render(context);
}
};
} // namespace minja

View file

@ -12,6 +12,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@ -1728,67 +1729,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
// Chat template utils
//
std::string common_get_builtin_chat_template(const struct llama_model * model) {
const char * ptr_tmpl = llama_model_chat_template(model);
return ptr_tmpl == nullptr ? "" : ptr_tmpl;
}
bool common_chat_verify_template(const std::string & tmpl) {
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass) {
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
}
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass) {
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@ -1796,21 +1805,74 @@ std::string common_chat_format_single(const struct llama_model * model,
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl) {
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return common_chat_apply_template(model, tmpl, msgs, true);
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
std::string default_template_src = chat_template_override;
std::string template_tool_use_src = chat_template_override;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
}
}
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
};
}
//

View file

@ -334,6 +334,7 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
std::vector<std::string> api_keys;
@ -603,30 +604,43 @@ struct common_chat_msg {
std::string content;
};
// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass);
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass);
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
//
// KV cache utils

2788
common/minja.hpp Normal file

File diff suppressed because it is too large Load diff