From ec6212ee64ff84a59cdb8d88998d32e02ca4818c Mon Sep 17 00:00:00 2001
From: MaggotHATE
Date: Wed, 20 Nov 2024 17:06:56 +0500
Subject: [PATCH] Reverted to a simple solution within server only.

---
 common/common.cpp                    | 22 +++++++---------------
 common/common.h                      |  6 +-----
 examples/main/main.cpp               |  2 +-
 examples/server/server.cpp           |  8 ++++----
 examples/server/utils.hpp            | 11 +++++++++--
 examples/simple-chat/simple-chat.cpp |  6 +++---
 include/llama.h                      |  2 --
 src/llama.cpp                        | 18 +-----------------
 tests/test-chat-template.cpp         |  4 +---
 9 files changed, 27 insertions(+), 52 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index d7d25470b..d314523db 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1559,14 +1559,12 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix,
         const std::vector<common_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
@@ -1578,12 +1576,10 @@ std::string common_chat_apply_template(const struct llama_model * model,
     }
 
     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    const char * ptr_prefix = prefix.empty() ? nullptr : prefix.c_str();
-    const char * ptr_suffix = suffix.empty() ? nullptr : suffix.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, ptr_prefix, ptr_suffix, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1593,7 +1589,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", nullptr, nullptr, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1604,8 +1600,6 @@ std::string common_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            fallback ? nullptr : ptr_prefix,
-            fallback ? nullptr : ptr_suffix,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
@@ -1619,9 +1613,7 @@ std::string common_chat_format_single(const struct llama_model * model,
         const common_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    const std::string prefix;
-    const std::string suffix;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, prefix, suffix, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1629,21 +1621,21 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, prefix, suffix, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl, const std::string & prefix, const std::string & suffix) {
+        const std::string & tmpl) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, prefix, suffix, msgs, true);
+    return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
 //
diff --git a/common/common.h b/common/common.h
index ad6af3613..7977cc7a9 100644
--- a/common/common.h
+++ b/common/common.h
@@ -523,8 +523,6 @@ bool common_chat_verify_template(const std::string & tmpl);
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix,
         const std::vector<common_chat_msg> & chat,
         bool add_ass);
 
@@ -537,9 +535,7 @@ std::string common_chat_format_single(const struct llama_model * model,
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix);
+        const std::string & tmpl);
 
 //
 // KV cache utils
 //
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index fe5dc6515..7c4ce4be2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template, params.input_suffix, params.input_suffix).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1d68cd644..c0f8c7e73 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -667,7 +667,7 @@ struct server_context {
         if (res >= 0) {
             llama_chat_message chat[] = {{"user", "test"}};
             std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
+            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
             return chat_res > 0;
         }
         return false;
@@ -3229,11 +3229,11 @@ int main(int argc, char ** argv) {
         }
     } else if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
         LOG_WRN("%s: Prefix and suffix are not used because a chat template is defined.\n", __func__);
+    } else {
+        // print sample chat example to make it clear which template is used
+        LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
     }
 
-    // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
-
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 09f3a0c85..d0d9775cf 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -302,6 +302,7 @@ static llama_tokens format_infill(
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::string & prefix, const std::string & suffix, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
+    std::string formatted_chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
@@ -325,10 +326,16 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
             throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }
 
-        chat.push_back({role, content});
+        if (tmpl == "custom") {
+            // simple format using prefix and suffix
+            if (role == "user") formatted_chat += prefix + content + suffix;
+            else formatted_chat += content;
+        } else {
+            chat.push_back({role, content});
+        }
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, prefix, suffix, chat, true);
+    if (tmpl != "custom") formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index d85a385ce..5f9973163 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -158,10 +158,10 @@ int main(int argc, char ** argv) {
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
@@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
diff --git a/include/llama.h b/include/llama.h
index 0483b527e..90791d5f5 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -981,8 +981,6 @@ extern "C" {
     LLAMA_API int32_t llama_chat_apply_template(
               const struct llama_model * model,
                             const char * tmpl,
-                            const char * prefix,
-                            const char * suffix,
       const struct llama_chat_message * chat,
                                  size_t   n_msg,
                                    bool   add_ass,
diff --git a/src/llama.cpp b/src/llama.cpp
index 789c7578d..c51b36e66 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21820,8 +21820,6 @@ int32_t llama_detokenize(
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
-    const std::string & prefix,
-    const std::string & suffix,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
@@ -22098,16 +22096,6 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
-    } else if (tmpl == "custom") {
-        // a custom template using only prefix and suffix
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << prefix << message->content << suffix;
-            } else {
-                ss << message->content;
-            }
-        }
     } else {
         // template not supported
         return -1;
@@ -22119,8 +22107,6 @@
 int32_t llama_chat_apply_template(
               const struct llama_model * model,
                             const char * tmpl,
-                            const char * prefix,
-                            const char * suffix,
       const struct llama_chat_message * chat,
                                  size_t   n_msg,
                                    bool   add_ass,
@@ -22149,9 +22135,7 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    std::string prefix_chat = (prefix == nullptr ? "" : prefix);
-    std::string suffix_chat = (suffix == nullptr ? "" : suffix);
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, prefix_chat, suffix_chat, chat_vec, formatted_chat, add_ass);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 51121533c..03e897e66 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -118,7 +118,7 @@ int main(void) {
     int32_t res;
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", nullptr, nullptr, conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
 
     for (size_t i = 0; i < templates.size(); i++) {
@@ -128,8 +128,6 @@
         res = llama_chat_apply_template(
             nullptr,
             custom_template.c_str(),
-            nullptr,
-            nullptr,
             conversation,
             message_count,
             true,