From 53e0215053bf6557edfe86630a3bca4035e69543 Mon Sep 17 00:00:00 2001
From: MaggotHATE
Date: Wed, 20 Nov 2024 14:42:03 +0500
Subject: [PATCH] Initial "prefix+suffix" chat template

* llama, common, server

* main handles prefix and suffix separately, so only the chat example display is adjusted

---
 common/arg.cpp             |  8 ++++----
 common/common.cpp          | 22 +++++++++++++++-------
 common/common.h            |  6 +++++-
 examples/main/main.cpp     |  2 +-
 examples/server/server.cpp | 13 +++++++++----
 examples/server/utils.hpp  | 10 ++++++----
 include/llama.h            |  2 ++
 src/llama.cpp              | 18 +++++++++++++++++-
 8 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 4115b2f75..97fedfcc3 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -850,17 +850,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "string to prefix user inputs with (default: empty)",
         [](common_params & params, const std::string & value) {
             params.input_prefix = value;
-            params.enable_chat_template = false;
+            // params.enable_chat_template = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL}));
     add_opt(common_arg(
         {"--in-suffix"}, "STRING",
         "string to suffix after user inputs with (default: empty)",
         [](common_params & params, const std::string & value) {
             params.input_suffix = value;
-            params.enable_chat_template = false;
+            // params.enable_chat_template = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL}));
     add_opt(common_arg(
         {"--no-warmup"},
         "skip warming up the model with an empty run",
diff --git a/common/common.cpp b/common/common.cpp
index d314523db..d7d25470b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1559,12 +1559,14 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_toke
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix,
         const std::vector<common_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
@@ -1576,10 +1578,12 @@ std::string common_chat_apply_template(const struct llama_model * model,
     }
 
     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_prefix = prefix.empty() ? nullptr : prefix.c_str();
+    const char * ptr_suffix = suffix.empty() ? nullptr : suffix.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, ptr_prefix, ptr_suffix, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1589,7 +1593,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            res = llama_chat_apply_template(nullptr, "chatml", nullptr, nullptr, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1600,6 +1604,8 @@ std::string common_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
+            fallback ? nullptr : ptr_prefix,
+            fallback ? nullptr : ptr_suffix,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
@@ -1613,7 +1619,9 @@ std::string common_chat_format_single(const struct llama_model * model,
         const common_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    const std::string prefix;
+    const std::string suffix;
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, prefix, suffix, past_msg, false);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1621,21 +1629,21 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, prefix, suffix, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl) {
+        const std::string & tmpl, const std::string & prefix, const std::string & suffix) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    return common_chat_apply_template(model, tmpl, prefix, suffix, msgs, true);
 }
 
 //
diff --git a/common/common.h b/common/common.h
index 7977cc7a9..ad6af3613 100644
--- a/common/common.h
+++ b/common/common.h
@@ -523,6 +523,8 @@ bool common_chat_verify_template(const std::string & tmpl);
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix,
         const std::vector<common_chat_msg> & chat,
         bool add_ass);
 
@@ -535,7 +537,9 @@ std::string common_chat_format_single(const struct llama_model * model,
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl);
+        const std::string & tmpl,
+        const std::string & prefix,
+        const std::string & suffix);
 
 //
 // KV cache utils
 //
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 7c4ce4be2..fe5dc6515 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b8e003be9..8777882db 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -667,7 +667,7 @@ struct server_context {
         if (res >= 0) {
             llama_chat_message chat[] = {{"user", "test"}};
             std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
             return chat_res > 0;
         }
         return false;
@@ -2829,7 +2829,7 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.input_prefix, params.input_suffix);
 
         std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3220,14 +3220,19 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
+            LOG_WRN("%s: Prefix and suffix are used instead of a chat template. This may cause the model to output suboptimal responses\n", __func__);
+            params.chat_template = "custom";
+        } else if (!ctx_server.validate_model_chat_template()) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }
+    } else if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
+        LOG_WRN("%s: Prefix and suffix are not used because a chat template is defined.\n", __func__);
     }
 
     // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c47ed3e47..09f3a0c85 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -300,7 +300,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::string & prefix, const std::string & suffix, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -328,7 +328,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(model, tmpl, prefix, suffix, chat, true);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -597,13 +597,15 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const std::string & chat_template,
+    const std::string & input_prefix,
+    const std::string & input_suffix) {
     json llama_params;
 
     llama_params["__oaicompat"] = true;
 
     // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    llama_params["prompt"] = format_chat(model, chat_template, input_prefix, input_suffix, body.at("messages"));
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
diff --git a/include/llama.h b/include/llama.h
index 90791d5f5..0483b527e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -981,6 +981,8 @@ extern "C" {
     LLAMA_API int32_t llama_chat_apply_template(
               const struct llama_model * model,
                             const char * tmpl,
+                            const char * prefix,
+                            const char * suffix,
       const struct llama_chat_message * chat,
                                 size_t   n_msg,
                                   bool   add_ass,
diff --git a/src/llama.cpp b/src/llama.cpp
index c51b36e66..789c7578d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21820,6 +21820,8 @@ int32_t llama_detokenize(
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
+    const std::string & prefix,
+    const std::string & suffix,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
@@ -22096,6 +22098,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
+    } else if (tmpl == "custom") {
+        // a custom template using only prefix and suffix
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << prefix << message->content << suffix;
+            } else {
+                ss << message->content;
+            }
+        }
     } else {
         // template not supported
         return -1;
@@ -22107,6 +22119,8 @@ static int32_t llama_chat_apply_template_internal(
 int32_t llama_chat_apply_template(
                 const struct llama_model * model,
                               const char * tmpl,
+                              const char * prefix,
+                              const char * suffix,
         const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
@@ -22135,7 +22149,9 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    std::string prefix_chat = (prefix == nullptr ? "" : prefix);
+    std::string suffix_chat = (suffix == nullptr ? "" : suffix);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, prefix_chat, suffix_chat, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }