diff --git a/common/chaton.hpp b/common/chaton.hpp
index da2834668..91b3d4802 100644
--- a/common/chaton.hpp
+++ b/common/chaton.hpp
@@ -6,30 +6,33 @@
 #include "llama.h"
 #include "log.h"
 
-inline std::string llama_chat_apply_template_simple(
+// Tag the passed message suitably as expected by the specified chat handshake template
+// and the role. If the specified template is not supported, the logic will return false.
+inline bool llama_chat_apply_template_simple(
         const std::string &tmpl,
         const std::string &role,
         const std::string &content,
+        std::string &dst,
         bool add_ass) {
     llama_chat_message msg = { role.c_str(), content.c_str() };
-    //std::vector<llama_chat_message> msgs{ msg };
-    std::vector<char> buf(content.size() * 2);
+    std::vector<char> buf(content.size() * 2); // This may under-allocate for small messages and over-allocate for large messages
 
     int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
-    LOG_TEELN("DBUG:%s:AA:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
     if (slen == -1) {
-        LOG_TEELN("WARN:%s:Unknown template [%s] encounted", __func__, tmpl.c_str());
-        return "";
+        LOG_TEELN("WARN:%s:Unknown template [%s] requested", __func__, tmpl.c_str());
+        dst = "";
+        return false;
     }
     if ((size_t) slen > buf.size()) {
+        LOGLN("INFO:%s:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
         buf.resize(slen);
         slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
-        LOG_TEELN("DBUG:%s:BB:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
     }
     const std::string tagged_msg(buf.data(), slen);
-    LOGLN("INFO:%s:%s", __func__, tagged_msg.c_str());
-    return tagged_msg;
+    LOGLN("INFO:%s:%s:%s", __func__, role.c_str(), tagged_msg.c_str());
+    dst = tagged_msg;
+    return true;
 }
 
 // return what should be the reverse prompt for the given template id
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e04b26d36..dfd16670e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -258,9 +258,10 @@ int main(int argc, char ** argv) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
         if (params.chaton) {
-            LOG_TEELN("DBUG:%s:AA:%s", __func__, params.prompt.c_str());
-            params.prompt = llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, false);
-            LOG_TEELN("DBUG:%s:BB:%s", __func__, params.prompt.c_str());
+            if (!llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, params.prompt, false)) {
+                LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "system", params.prompt.c_str());
+                exit(2);
+            }
         }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
@@ -897,7 +898,11 @@ int main(int argc, char ** argv) {
             std::vector<llama_token> line_inp;
             if (params.chaton) {
-                std::string f_chat = llama_chat_apply_template_simple(params.chaton_template_id, "user", buffer.c_str(), true);
+                std::string f_chat;
+                if (!llama_chat_apply_template_simple(params.chaton_template_id, "user", buffer.c_str(), f_chat, true)) {
+                    LOG_TEELN("ERRR:%s:Wrt:%s:%s:%s", __func__, params.chaton_template_id.c_str(), "user", buffer.c_str());
+                    exit(2);
+                }
                 line_inp = ::llama_tokenize(ctx, f_chat, false, true);
                 LOG("formatted input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
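
Usage note (illustrative, not part of the patch): a minimal sketch of the new
out-parameter calling convention, assuming a llama.cpp build where
common/chaton.hpp is on the include path and "chatml" is a template id
recognised by llama_chat_apply_template. The helper name and its argument
order come from the diff above; everything else here is assumed for
demonstration only.

#include <cstdio>
#include <string>
#include "chaton.hpp"

int main() {
    std::string tagged;
    // On success the tagged message is written into the out parameter;
    // on an unknown template id the helper now reports failure through
    // its return value instead of returning an empty string.
    if (!llama_chat_apply_template_simple("chatml", "user", "hello", tagged, true)) {
        return 2; // mirrors the exit(2) convention used in main.cpp above
    }
    printf("%s\n", tagged.c_str());
    return 0;
}

Passing the same std::string as both content and dst, as the system-prompt
path in main.cpp does with params.prompt, is safe with this implementation,
since dst is only assigned after the last read of content inside the helper.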