Add --jinja and --chat-template-file flags

This commit is contained in:
ochafik 2024-12-30 03:40:34 +00:00
parent abd274a48f
commit e5113e8d74
12 changed files with 289 additions and 50 deletions

View file

@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.hpp
common.cpp
common.h
console.cpp
@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
json.hpp
log.cpp
log.h
minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp

View file

@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
if (!common_chat_verify_template(value)) {
if (!common_chat_verify_template(value, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s\n"
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
value.c_str()
"error: the supplied chat template is not supported: %s%s\n",
value.c_str(),
params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
));
}
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
"set custom jinja chat template file (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::string chat_template;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(chat_template)
);
if (!common_chat_verify_template(chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
value.c_str(),
params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
));
}
params.chat_template = chat_template;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
add_opt(common_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

View file

@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize(
return result;
}
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
return piece;
}
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
return _common_token_to_piece(llama_get_model(ctx), token, special);
}
std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
// Chat template utils
//
bool common_chat_verify_template(const std::string & tmpl) {
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model,
return common_chat_apply_template(model, tmpl, msgs, true);
}
static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
std::string default_template_src = chat_template_override;
std::string tool_use_template_src = chat_template_override;
if (chat_template_override.empty()) {
default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!tool_use_template_src.empty()) {
default_template_src = tool_use_template_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
}
}
return {
.default_template = { default_template_src, bos_token, eos_token },
.tool_use_template = tool_use_template_src.empty() ? std::nullopt
: std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
};
}
//
// KV cache utils
//

View file

@ -3,6 +3,7 @@
#pragma once
#include "llama.h"
#include "chat-template.hpp"
#include <string>
#include <vector>
@ -324,6 +325,7 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
std::vector<std::string> api_keys;
@ -571,8 +573,8 @@ struct common_chat_msg {
std::string content;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);
// Check if the template is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model,
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
struct llama_chat_templates {
minja::chat_template default_template;
std::optional<minja::chat_template> tool_use_template;
};
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
//
// KV cache utils
//