Add --jinja and --chat-template-file flags
parent abd274a48f
commit e5113e8d74
12 changed files with 289 additions and 50 deletions
@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp
@@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
+            if (!common_chat_verify_template(value, params.use_jinja)) {
                 throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                 ));
             }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        "set custom jinja chat template file (default: template taken from model's metadata)\n"
+        "if suffix/prefix are specified, template will be disabled\n"
+        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string chat_template;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(chat_template)
+            );
+            if (!common_chat_verify_template(chat_template, params.use_jinja)) {
+                throw std::runtime_error(string_format(
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
+                ));
+            }
+            params.chat_template = chat_template;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
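The new --chat-template-file handler slurps the whole file into a string with std::istreambuf_iterator before passing it to common_chat_verify_template; both template flags can also be supplied through the LLAMA_ARG_CHAT_TEMPLATE / LLAMA_ARG_CHAT_TEMPLATE_FILE environment variables, and per the help text --jinja must be given before them for the relaxed jinja validation to apply. A minimal standalone sketch of the same file-slurp idiom (the template.jinja path and the byte-count printout are illustrative, not part of the patch):

```cpp
// Standalone sketch of the file-slurp idiom used by --chat-template-file.
#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>

static std::string read_file_to_string(const std::string & path) {
    std::ifstream file(path);
    if (!file) {
        throw std::runtime_error("failed to open file '" + path + "'");
    }
    std::string contents;
    std::copy(std::istreambuf_iterator<char>(file),
              std::istreambuf_iterator<char>(),
              std::back_inserter(contents));
    return contents;
}

int main() {
    try {
        // "template.jinja" is a placeholder path; in the patch the resulting string is
        // validated with common_chat_verify_template(contents, params.use_jinja).
        std::string contents = read_file_to_string("template.jinja");
        std::cout << "read " << contents.size() << " bytes\n";
    } catch (const std::exception & e) {
        std::cerr << e.what() << "\n";
        return 1;
    }
    return 0;
}
```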
@@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize(
     return result;
 }

-std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
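The refactored _common_token_to_piece keeps llama_token_to_piece's size-negotiation contract: the first call writes into the string's small-string-optimization buffer, a negative return value reports how many bytes are actually needed, and the buffer is then resized and the call repeated once. A self-contained sketch of that contract with a hypothetical stand-in for llama_token_to_piece (fake_to_piece is not a llama.cpp function):

```cpp
// Sketch of the negative-return "tell me the size, then retry" contract.
#include <cassert>
#include <cstring>
#include <iostream>
#include <string>

// Hypothetical stand-in for llama_token_to_piece: copies src into buf when it fits,
// otherwise returns the negated number of bytes that would be required.
static int fake_to_piece(const std::string & src, char * buf, int buf_size) {
    if ((int) src.size() > buf_size) {
        return -(int) src.size();
    }
    std::memcpy(buf, src.data(), src.size());
    return (int) src.size();
}

static std::string to_piece(const std::string & src) {
    std::string piece;
    piece.resize(piece.capacity());              // start from the SSO buffer, as in _common_token_to_piece
    int n_chars = fake_to_piece(src, &piece[0], (int) piece.size());
    if (n_chars < 0) {
        piece.resize(-n_chars);                  // grow to the reported size and retry once
        int check = fake_to_piece(src, &piece[0], (int) piece.size());
        assert(check == -n_chars);               // the retry must report the same length
        n_chars = check;
    }
    piece.resize(n_chars);                       // trim to the bytes actually written
    return piece;
}

int main() {
    std::cout << to_piece("a token piece longer than the small-string buffer") << "\n";
    return 0;
}
```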
@@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }

+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    return _common_token_to_piece(llama_get_model(ctx), token, special);
+}
+
 std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
@@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 // Chat template utils
 //

-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
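With use_jinja set, validation goes through minja instead of llama_chat_apply_template: the template is compiled with placeholder "<s>"/"</s>" BOS/EOS tokens and rendered against a single test message, and any exception marks it as unsupported. A hedged sketch that exercises the same calls directly; it assumes the vendored chat-template.hpp/json.hpp headers build standalone and that a json alias for nlohmann's ordered type is appropriate here (both assumptions, not confirmed by this diff). The ChatML-style template string is illustrative:

```cpp
#include <exception>
#include <iostream>
#include <string>

#include "chat-template.hpp"   // added to common/ by this commit
#include "json.hpp"

using json = nlohmann::ordered_json;   // assumption: the alias used alongside minja in common

int main() {
    // Illustrative ChatML-style template, same shape as the fallback in llama_chat_templates_from_model.
    const std::string tmpl_src =
        "{%- for message in messages -%}"
        "{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>\\n' -}}"
        "{%- endfor -%}"
        "{%- if add_generation_prompt -%}{{- '<|im_start|>assistant\\n' -}}{%- endif -%}";

    try {
        // Same calls as the use_jinja branch of common_chat_verify_template.
        auto chat_template = minja::chat_template(tmpl_src, "<s>", "</s>");
        chat_template.apply({{
            {"role",    "user"},
            {"content", "test"},
        }}, json(), /* add_generation_prompt= */ true);
        std::cout << "template is supported\n";
    } catch (const std::exception & e) {
        std::cerr << "template rejected: " << e.what() << "\n";
        return 1;
    }
    return 0;
}
```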
@@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }

+static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
+    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
+    if (tlen > 0) {
+        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
+        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
+            return std::string(curr_tmpl_buf.data(), tlen);
+        }
+    }
+    return "";
+}
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
+    auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
+    std::string default_template_src = chat_template_override;
+    std::string tool_use_template_src = chat_template_override;
+    if (chat_template_override.empty()) {
+        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!tool_use_template_src.empty()) {
+            default_template_src = tool_use_template_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    return {
+        .default_template = { default_template_src, bos_token, eos_token },
+        .tool_use_template = tool_use_template_src.empty() ? std::nullopt
+            : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
+    };
+}
+
 //
 // KV cache utils
 //
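llama_chat_templates_from_model resolves the template sources in order: an explicit override, then the model's tokenizer.chat_template / tokenizer.chat_template.tool_use GGUF metadata, then a built-in ChatML fallback, and wraps them in minja::chat_template objects together with the model's BOS/EOS pieces. A hedged usage sketch, assuming an already-loaded llama_model *, that minja::chat_template::apply returns the rendered prompt as a std::string (its return value is not shown in this diff), and that nlohmann's ordered_json matches the json type used above:

```cpp
#include <string>

#include "common.h"
#include "json.hpp"

using json = nlohmann::ordered_json;   // assumption, see note above

// Render a chat prompt with the templates bundled in an already-loaded model.
static std::string render_prompt(const struct llama_model * model, const std::string & override_tmpl) {
    // Empty override -> use tokenizer.chat_template from the model (or the ChatML fallback).
    llama_chat_templates templates = llama_chat_templates_from_model(model, override_tmpl);

    json messages = json::array();
    messages.push_back({ {"role", "system"}, {"content", "You are a helpful assistant."} });
    messages.push_back({ {"role", "user"},   {"content", "Hello!"} });

    // apply(messages, tools, add_generation_prompt), as used by common_chat_verify_template.
    return templates.default_template.apply(messages, json(), /* add_generation_prompt= */ true);
}
```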
@@ -3,6 +3,7 @@
 #pragma once

 #include "llama.h"
+#include "chat-template.hpp"

 #include <string>
 #include <vector>
@@ -324,6 +325,7 @@ struct common_params {
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";   // NOLINT
     std::string chat_template = "";   // NOLINT
+    bool use_jinja = false;           // NOLINT
     bool enable_chat_template = true;

     std::vector<std::string> api_keys;
@@ -571,8 +573,8 @@ struct common_chat_msg {
     std::string content;
 };

-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+// Check if the template is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
@@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model,
 std::string common_chat_format_example(const struct llama_model * model,
     const std::string & tmpl);

+
+struct llama_chat_templates {
+    minja::chat_template default_template;
+    std::optional<minja::chat_template> tool_use_template;
+};
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
+
 //
 // KV cache utils
 //
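Because tool_use_template is a std::optional, a caller can prefer the model's dedicated tool-use variant and fall back to default_template otherwise; a small sketch (the pick_template helper is illustrative, not part of the patch):

```cpp
#include "common.h"

// Illustrative helper: prefer the dedicated tool-use template when the model ships one.
static const minja::chat_template & pick_template(const llama_chat_templates & templates) {
    return templates.tool_use_template.has_value()
        ? *templates.tool_use_template
        : templates.default_template;
}
```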