Reverted to a simple solution within the server only.

MaggotHATE 2024-11-20 17:06:56 +05:00
parent fc050381ef
commit ec6212ee64
9 changed files with 27 additions and 52 deletions
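
In effect, the commit removes the prefix and suffix parameters that a previous change had added to llama_chat_apply_template and the common_chat_* helpers, restoring the plain signature, and keeps prefix/suffix handling inside the server's format_chat only. Below is a minimal sketch of the restored two-pass calling convention as it appears in the hunks that follow; the helper name, message contents and initial buffer size are illustrative, and a loaded llama_model pointer is assumed.

#include "llama.h"

#include <string>
#include <vector>

// Sketch: format a conversation with the model's built-in chat template using the
// reverted signature (no prefix/suffix arguments). Run once, and if the result did
// not fit, grow the buffer and run again -- the same two-pass pattern visible in
// the diff below.
static std::string apply_template_sketch(const llama_model * model, bool add_ass) {
    std::vector<llama_chat_message> chat = {
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "How are you?"},
    };

    std::vector<char> buf(1024); // illustrative initial size
    int32_t res = llama_chat_apply_template(model, nullptr, chat.data(), chat.size(),
                                            add_ass, buf.data(), buf.size());
    if (res < 0) {
        return ""; // the chat template is not supported
    }
    if (res > (int32_t) buf.size()) {
        // the first pass returned the required length; resize and run again
        buf.resize(res);
        res = llama_chat_apply_template(model, nullptr, chat.data(), chat.size(),
                                        add_ass, buf.data(), buf.size());
    }
    return std::string(buf.data(), res);
}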

View file

@@ -1559,14 +1559,12 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix,
         const std::vector<common_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
@@ -1578,12 +1576,10 @@ std::string common_chat_apply_template(const struct llama_model * model,
     }
     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    const char * ptr_prefix = prefix.empty() ? nullptr : prefix.c_str();
-    const char * ptr_suffix = suffix.empty() ? nullptr : suffix.c_str();
     std::vector<char> buf(alloc_size);
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, ptr_prefix, ptr_suffix, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     // error: chat template is not supported
     if (res < 0) {
@@ -1593,7 +1589,7 @@ std::string common_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", nullptr, nullptr, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1604,8 +1600,6 @@ std::string common_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            fallback ? nullptr : ptr_prefix,
-            fallback ? nullptr : ptr_suffix,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
@@ -1619,9 +1613,7 @@ std::string common_chat_format_single(const struct llama_model * model,
         const common_chat_msg & new_msg,
         bool add_ass) {
     std::ostringstream ss;
-    const std::string prefix;
-    const std::string suffix;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, prefix, suffix, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1629,21 +1621,21 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, prefix, suffix, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl, const std::string & prefix, const std::string & suffix) {
+        const std::string & tmpl) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, prefix, suffix, msgs, true);
+    return common_chat_apply_template(model, tmpl, msgs, true);
 }
 //

View file

@@ -523,8 +523,6 @@ bool common_chat_verify_template(const std::string & tmpl);
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix,
         const std::vector<common_chat_msg> & chat,
         bool add_ass);
@@ -537,9 +535,7 @@ std::string common_chat_format_single(const struct llama_model * model,
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl,
-        const std::string & prefix,
-        const std::string & suffix);
+        const std::string & tmpl);
 //
 // KV cache utils

View file

@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template, params.input_suffix, params.input_suffix).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }

View file

@@ -667,7 +667,7 @@ struct server_context {
         if (res >= 0) {
             llama_chat_message chat[] = {{"user", "test"}};
             std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), nullptr, nullptr, chat, 1, true, nullptr, 0);
+            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
             return chat_res > 0;
         }
         return false;
@@ -3229,11 +3229,11 @@ int main(int argc, char ** argv) {
        }
    } else if (!params.input_prefix.empty() || !params.input_suffix.empty()) {
        LOG_WRN("%s: Prefix and suffix are not used because a chat template is defined.\n", __func__);
+   } else {
+       // print sample chat example to make it clear which template is used
+       LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
    }
-   // print sample chat example to make it clear which template is used
-   LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template, params.input_prefix, params.input_suffix).c_str());
    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));

View file

@@ -302,6 +302,7 @@ static llama_tokens format_infill(
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::string & prefix, const std::string & suffix, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
+    std::string formatted_chat;
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
@@ -325,10 +326,16 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
             throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }
-        chat.push_back({role, content});
+        if (tmpl == "custom") {
+            // simple format using prefix and suffix
+            if (role == "user") formatted_chat += prefix + content + suffix;
+            else formatted_chat += content;
+        } else {
+            chat.push_back({role, content});
+        }
     }
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, prefix, suffix, chat, true);
+    if (tmpl != "custom") formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
     return formatted_chat;
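
With this change the server builds the prompt itself when tmpl == "custom": user turns are wrapped with the configured prefix and suffix, every other turn is appended verbatim, and any other template value still goes through common_chat_apply_template. This mirrors the branch removed from llama_chat_apply_template_internal further down. Below is a self-contained sketch of that formatting rule; the msg struct, function name and the Alpaca-style prefix/suffix are illustrative only.

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the server's message representation.
struct msg { std::string role, content; };

// Sketch of the "custom" branch in format_chat(): user turns are wrapped with the
// configured prefix/suffix, all other turns are appended unchanged.
static std::string format_custom(const std::string & prefix, const std::string & suffix,
                                 const std::vector<msg> & messages) {
    std::string formatted_chat;
    for (const auto & m : messages) {
        if (m.role == "user") {
            formatted_chat += prefix + m.content + suffix;
        } else {
            formatted_chat += m.content;
        }
    }
    return formatted_chat;
}

int main() {
    std::vector<msg> messages = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
    };
    // illustrative Alpaca-style prefix/suffix
    std::cout << format_custom("### Instruction:\n", "\n### Response:\n", messages) << "\n";
}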

View file

@@ -158,10 +158,10 @@ int main(int argc, char ** argv) {
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
@@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, nullptr, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;

View file

@@ -981,8 +981,6 @@ extern "C" {
     LLAMA_API int32_t llama_chat_apply_template(
             const struct llama_model * model,
             const char * tmpl,
-            const char * prefix,
-            const char * suffix,
             const struct llama_chat_message * chat,
             size_t n_msg,
             bool add_ass,

View file

@@ -21820,8 +21820,6 @@ int32_t llama_detokenize(
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
-    const std::string & prefix,
-    const std::string & suffix,
     const std::vector<const llama_chat_message *> & chat,
     std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
@@ -22098,16 +22096,6 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
-    } else if (tmpl == "custom") {
-        // a custom template using only prefix and suffix
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << prefix << message->content << suffix;
-            } else {
-                ss << message->content;
-            }
-        }
     } else {
         // template not supported
         return -1;
@@ -22119,8 +22107,6 @@ static int32_t llama_chat_apply_template_internal(
 int32_t llama_chat_apply_template(
     const struct llama_model * model,
     const char * tmpl,
-    const char * prefix,
-    const char * suffix,
     const struct llama_chat_message * chat,
     size_t n_msg,
     bool add_ass,
@@ -22149,9 +22135,7 @@ int32_t llama_chat_apply_template(
     }
     std::string formatted_chat;
-    std::string prefix_chat = (prefix == nullptr ? "" : prefix);
-    std::string suffix_chat = (suffix == nullptr ? "" : suffix);
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, prefix_chat, suffix_chat, chat_vec, formatted_chat, add_ass);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }

View file

@@ -118,7 +118,7 @@ int main(void) {
     int32_t res;
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", nullptr, nullptr, conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
     for (size_t i = 0; i < templates.size(); i++) {
@@ -128,8 +128,6 @@ int main(void) {
         res = llama_chat_apply_template(
             nullptr,
             custom_template.c_str(),
-            nullptr,
-            nullptr,
             conversation,
             message_count,
             true,