Initial "prefix+suffix" chat template
* llama, common, server
* main handles prefix and suffix separately, only adjusting the example display
parent ad21c9e1f1
commit 53e0215053
8 changed files with 59 additions and 22 deletions
@@ -850,17 +850,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "string to prefix user inputs with (default: empty)",
         [](common_params & params, const std::string & value) {
             params.input_prefix = value;
-            params.enable_chat_template = false;
+            // params.enable_chat_template = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL}));
     add_opt(common_arg(
         {"--in-suffix"}, "STRING",
         "string to suffix after user inputs with (default: empty)",
         [](common_params & params, const std::string & value) {
             params.input_suffix = value;
-            params.enable_chat_template = false;
+            // params.enable_chat_template = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL}));
     add_opt(common_arg(
         {"--no-warmup"},
         "skip warming up the model with an empty run",
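For orientation, both options above use the common_arg builder pattern visible in this hunk: flag name, value hint, help text, a setter lambda, and the set of examples the option is exposed to (now including LLAMA_EXAMPLE_SERVER). A minimal sketch of the same pattern; the --my-option flag and the my_option field are hypothetical, purely illustrative:

    // hypothetical option registered with the same builder pattern as --in-prefix/--in-suffix
    add_opt(common_arg(
        {"--my-option"}, "STRING",
        "string stored into common_params (default: empty)",
        [](common_params & params, const std::string & value) {
            params.my_option = value; // hypothetical field; --in-prefix writes params.input_prefix instead
        }
    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));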
@@ -1559,12 +1559,14 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix,
         const std::vector<common_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
@@ -1576,10 +1578,12 @@ std::string common_chat_apply_template(const struct llama_model * model,
     }
 
     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_prefix = prefix.empty() ? nullptr : prefix.c_str();
+    const char * ptr_suffix = suffix.empty() ? nullptr : suffix.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, ptr_prefix, ptr_suffix, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1589,7 +1593,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            res = llama_chat_apply_template(nullptr, "chatml", nullptr, nullptr, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1600,6 +1604,8 @@ std::string common_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
+            fallback ? nullptr : ptr_prefix,
+            fallback ? nullptr : ptr_suffix,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
@@ -1613,7 +1619,9 @@ std::string common_chat_format_single(const struct llama_model * model,
         const common_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    const std::string prefix;
+    const std::string suffix;
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, prefix, suffix, past_msg, false);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1621,21 +1629,21 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, prefix, suffix, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl) {
+        const std::string & tmpl, const std::string & prefix, const std::string & suffix) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    return common_chat_apply_template(model, tmpl, prefix, suffix, msgs, true);
 }
 
 //
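With the two extra parameters in place, a caller that wants bare prefix/suffix formatting can name the "custom" template explicitly and pass the strings that --in-prefix/--in-suffix would provide. A minimal sketch, assuming common_chat_msg is the simple role/content pair seen in the initializer lists above; the prefix and suffix values are arbitrary examples:

    // sketch: format a short conversation through the extended helper
    std::vector<common_chat_msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    // "custom" is the prefix+suffix template added by this commit; the model pointer
    // can stay nullptr because the template name is supplied explicitly
    std::string prompt = common_chat_apply_template(
        /* model   */ nullptr,
        /* tmpl    */ "custom",
        /* prefix  */ "USER: ",
        /* suffix  */ "\nASSISTANT: ",
        msgs,
        /* add_ass */ true);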
@@ -523,6 +523,8 @@ bool common_chat_verify_template(const std::string & tmpl);
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix,
         const std::vector<common_chat_msg> & chat,
         bool add_ass);
 
@@ -535,7 +537,9 @@ std::string common_chat_format_single(const struct llama_model * model,
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl);
+        const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix);
 
 //
 // KV cache utils
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -667,7 +667,7 @@ struct server_context {
         if (res >= 0) {
             llama_chat_message chat[] = {{"user", "test"}};
             std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
             return chat_res > 0;
         }
         return false;
@@ -2829,7 +2829,7 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.input_prefix, params.input_suffix);
 
         std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3220,14 +3220,19 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
+            LOG_WRN("%s: Prefix and suffix are used instead of a chat template. This may cause the model to output suboptimal responses\n", __func__);
+            params.chat_template = "custom";
+        } else if (!ctx_server.validate_model_chat_template()) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }
+    } else if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
+        LOG_WRN("%s: Prefix and suffix are not used because a chat template is defined.\n", __func__);
     }
 
     // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
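The selection block above establishes a clear precedence: an explicit --chat-template always wins, otherwise a supplied prefix/suffix maps to the internal "custom" template, otherwise the model's own template is used, with chatml as the last resort. A condensed restatement of that precedence as a standalone helper (illustrative only, not code from the commit):

    // illustrative restatement of the server's template selection
    static std::string pick_chat_template(const std::string & requested_template,
                                          const std::string & input_prefix,
                                          const std::string & input_suffix,
                                          bool model_template_supported) {
        if (!requested_template.empty()) {
            return requested_template;   // explicit --chat-template wins
        }
        if (!input_prefix.empty() || !input_suffix.empty()) {
            return "custom";             // prefix/suffix mode added by this commit
        }
        if (model_template_supported) {
            return "";                   // empty string -> use the model's built-in template
        }
        return "chatml";                 // final fallback
    }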
@@ -300,7 +300,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::string & prefix, const std::string & suffix, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -328,7 +328,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(model, tmpl, prefix, suffix, chat, true);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -597,13 +597,15 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const std::string & chat_template,
+    const std::string & input_prefix,
+    const std::string & input_suffix) {
     json llama_params;
 
     llama_params["__oaicompat"] = true;
 
     // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    llama_params["prompt"] = format_chat(model, chat_template, input_prefix, input_suffix, body.at("messages"));
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
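On the OpenAI-compatible endpoint, the request's messages array now reaches format_chat together with the configured prefix and suffix. A rough sketch of that call with a hand-built messages vector (json is nlohmann::json as in the surrounding code; the message content is a placeholder):

    // sketch: roughly what the chat-completions path now feeds into format_chat
    std::vector<json> messages = {
        json{{"role", "user"}, {"content", "Hello"}},
    };
    const std::string prompt = format_chat(
        ctx_server.model,        // model owned by the running server context
        params.chat_template,    // may be "custom" after the selection logic above
        params.input_prefix,
        params.input_suffix,
        messages);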
@@ -981,6 +981,8 @@ extern "C" {
     LLAMA_API int32_t llama_chat_apply_template(
               const struct llama_model * model,
                             const char * tmpl,
+                            const char * prefix,
+                            const char * suffix,
        const struct llama_chat_message * chat,
                                 size_t   n_msg,
                                   bool   add_ass,
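A minimal sketch of calling the extended C API directly. The two new arguments may be nullptr when a regular named template is used, mirroring the nullptr, nullptr call sites earlier in this diff; here the "custom" template is requested with arbitrary example strings, and the first call with a null buffer is assumed to report the required length, as the verification helpers above rely on:

    #include <string>
    #include <vector>
    #include "llama.h"

    // sketch: size query first, then format into a buffer
    static std::string format_custom_prompt() {
        llama_chat_message chat[] = {{"user", "Hello"}};
        int32_t needed = llama_chat_apply_template(nullptr, "custom", "USER: ", "\nASSISTANT: ",
                                                   chat, 1, /* add_ass */ true, nullptr, 0);
        if (needed < 0) {
            return ""; // template not supported
        }
        std::vector<char> buf(needed);
        int32_t written = llama_chat_apply_template(nullptr, "custom", "USER: ", "\nASSISTANT: ",
                                                    chat, 1, true, buf.data(), buf.size());
        return std::string(buf.data(), written > 0 ? (size_t) written : 0);
    }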
@@ -21820,6 +21820,8 @@ int32_t llama_detokenize(
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
+    const std::string & prefix,
+    const std::string & suffix,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
@@ -22096,6 +22098,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
+    } else if (tmpl == "custom") {
+        // a custom template using only prefix and suffix
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << prefix << message->content << suffix;
+            } else {
+                ss << message->content;
+            }
+        }
     } else {
         // template not supported
         return -1;
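The new branch wraps only user turns in prefix and suffix; system and assistant content passes through unchanged, and add_ass has no effect on it. A tiny self-contained sketch of the same loop, handy for checking the expected output (the struct and function names here are illustrative, not part of the patch):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    struct msg { std::string role, content; };

    // mirrors the "custom" branch: prefix/suffix around user messages only
    static std::string apply_custom(const std::string & prefix, const std::string & suffix,
                                    const std::vector<msg> & chat) {
        std::ostringstream ss;
        for (const auto & m : chat) {
            if (m.role == "user") {
                ss << prefix << m.content << suffix;
            } else {
                ss << m.content;
            }
        }
        return ss.str();
    }

    int main() {
        const std::vector<msg> chat = {
            {"system", "You are a helpful assistant"},
            {"user",   "Hello"},
        };
        // e.g. with --in-prefix "USER: " and --in-suffix "\nASSISTANT: "
        std::cout << apply_custom("USER: ", "\nASSISTANT: ", chat) << "\n";
        // prints: You are a helpful assistantUSER: Hello
        //         ASSISTANT:
    }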
@@ -22107,6 +22119,8 @@
 int32_t llama_chat_apply_template(
                 const struct llama_model * model,
                               const char * tmpl,
+                              const char * prefix,
+                              const char * suffix,
          const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
@@ -22135,7 +22149,9 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    std::string prefix_chat = (prefix == nullptr ? "" : prefix);
+    std::string suffix_chat = (suffix == nullptr ? "" : suffix);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, prefix_chat, suffix_chat, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }