From e983c9d0dede0cf480b46279225b52c15f0c78c8 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Wed, 25 Sep 2024 22:02:58 +0100
Subject: [PATCH] `tool-call`: fix llama_chat_apply_template signature /
 test-chat-template

---
 common/common.cpp            | 14 +++++++-------
 common/common.h              |  4 ++--
 examples/server/utils.hpp    |  2 +-
 tests/test-chat-template.cpp |  9 ++++++---
 4 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index bcf49f186..a757faf5f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1521,7 +1521,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass,
         bool use_jinja,
-        const std::string & tools,
+        const char * tools,
         const char * bos_token,
         const char * eos_token) {
     int alloc_size = 0;
@@ -1536,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? nullptr : tools.data(), bos_token, eos_token);
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1546,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
             fallback = true;
         }
     }
@@ -1557,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1570,11 +1570,11 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass,
         bool use_jinja,
-        const std::string & tools,
+        const char * tools,
         const char * bos_token,
         const char * eos_token) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1582,7 +1582,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
diff --git a/common/common.h b/common/common.h
index a42c675cc..1b5683c00 100644
--- a/common/common.h
+++ b/common/common.h
@@ -493,7 +493,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & chat,
         bool add_ass,
         bool use_jinja = false,
-        const std::string & tools = "",
+        const char * tools = nullptr,
         const char * bos_token = nullptr,
         const char * eos_token = nullptr);
 
@@ -504,7 +504,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass,
         bool use_jinja = false,
-        const std::string & tools = "",
+        const char * tools = nullptr,
         const char * bos_token = nullptr,
         const char * eos_token = nullptr);
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index a80a1b5dd..f28f7086d 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -97,7 +97,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.emplace_back(std::move(msg));
     }
 
-    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? "" : tools.dump());
+    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 114ce5928..68fe6c381 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -27,6 +27,8 @@ int main(void) {
         {"user", "Another question"},
     };
 
+    std::string tools = "";
+
     std::vector templates {
         {
             .name = "teknium/OpenHermes-2.5-Mistral-7B",
@@ -160,7 +162,7 @@ int main(void) {
     int32_t res;
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, "<|im_start|>", "<|im_end|>");
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
     assert(res < 0);
 
     for (auto use_jinja : std::vector<bool> { false, true }) {
@@ -182,6 +184,7 @@ int main(void) {
                 formatted_chat.data(),
                 formatted_chat.size(),
                 use_jinja,
+                tools.empty() ? nullptr : tools.c_str(),
                 tmpl.bos.c_str(),
                 tmpl.eos.c_str()
             );
@@ -210,7 +213,7 @@ int main(void) {
     llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
 
     auto fmt_sys = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, "<|im_start|>", "<|im_end|>");
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, /* tools= */ "", "<|im_start|>", "<|im_end|>");
         printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
@@ -229,7 +232,7 @@ int main(void) {
     llama_chat_msg new_msg{"user", "How are you"};
 
     auto fmt_single = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, "<|im_start|>", "<|im_end|>");
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
         printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
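
Note (illustration only, not part of the patch): a minimal sketch of how a caller might use the
common/common.h helper after this signature change. The build_prompt function, the model pointer
and the tools JSON payload below are hypothetical placeholders; only the
llama_chat_apply_template signature and llama_chat_msg come from the patch above.

    #include "common.h"

    // Sketch: format a chat prompt, forwarding an optional tools JSON array.
    std::string build_prompt(const struct llama_model * model) {
        std::vector<llama_chat_msg> chat {
            {"system", "You are a helpful assistant"},
            {"user",   "What is the weather in Paris?"},
        };
        // Hypothetical tools array; pass nullptr when no tools are offered.
        std::string tools = R"([{"type": "function", "function": {"name": "get_weather", "parameters": {}}}])";
        return llama_chat_apply_template(
            model, /* tmpl= */ "", chat, /* add_ass= */ true,
            /* use_jinja= */ true,
            tools.empty() ? nullptr : tools.c_str(),
            /* bos_token= */ nullptr, /* eos_token= */ nullptr);
    }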