diff --git a/common/chaton.hpp b/common/chaton.hpp
index 62e82d658..da2834668 100644
--- a/common/chaton.hpp
+++ b/common/chaton.hpp
@@ -7,18 +7,24 @@
 #include "log.h"
 
 inline std::string llama_chat_apply_template_simple(
-    const std::string & tmpl,
+    const std::string &tmpl,
     const std::string &role,
     const std::string &content,
     bool add_ass) {
     llama_chat_message msg = { role.c_str(), content.c_str() };
-    std::vector<llama_chat_message> msgs{ msg };
+    //std::vector<llama_chat_message> msgs{ msg };
     std::vector<char> buf(content.size() * 2);
 
-    int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), msgs.data(), msgs.size(), add_ass, buf.data(), buf.size());
+    int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+    LOG_TEELN("DBUG:%s:AA:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
+    if (slen == -1) {
+        LOG_TEELN("WARN:%s:Unknown template [%s] encountered", __func__, tmpl.c_str());
+        return "";
+    }
     if ((size_t) slen > buf.size()) {
         buf.resize(slen);
-        slen = llama_chat_apply_template(nullptr, tmpl.c_str(), msgs.data(), msgs.size(), add_ass, buf.data(), buf.size());
+        slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+        LOG_TEELN("DBUG:%s:BB:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
     }
 
     const std::string tagged_msg(buf.data(), slen);
@@ -28,11 +34,13 @@ inline std::string llama_chat_apply_template_simple(
 
 // return what should be the reverse prompt for the given template id
 // ie possible end text tag(s) of specified model type's chat query response
-std::vector<std::string> llama_chat_reverse_prompt(std::string &template_id) {
+inline std::vector<std::string> llama_chat_reverse_prompt(std::string &template_id) {
     std::vector<std::string> rends;
 
     if (template_id == "chatml") {
         rends.push_back("<|im_start|>user\n");
+    } else if (template_id == "llama2") {
+        rends.push_back("</s>");
     } else if (template_id == "llama3") {
         rends.push_back("<|eot_id|>");
     }
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a073a7bfd..e04b26d36 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -258,7 +258,9 @@ int main(int argc, char ** argv) {
                 params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
             }
             if (params.chaton) {
+                LOG_TEELN("DBUG:%s:AA:%s", __func__, params.prompt.c_str());
                 params.prompt = llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, false);
+                LOG_TEELN("DBUG:%s:BB:%s", __func__, params.prompt.c_str());
             }
             embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         } else {
@@ -372,7 +374,7 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         std::vector<std::string> resp_ends = llama_chat_reverse_prompt(params.chaton_template_id);
         if (resp_ends.size() == 0) {
-            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatType:%s", __func__, params.chaton_template_id.c_str());
+            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatTemplateType:%s", __func__, params.chaton_template_id.c_str());
             exit(1);
         }
         for (size_t i = 0; i < resp_ends.size(); i++)