Common:ChatOn+Main:DBUG: Cleanup ChatTmplSimp, RevPrompt Llama2

This is a commit with DBUG log messages still in place.

ChatApplyTemplateSimple

* Wasn't handling unknown template ids properly; these are now identified
and a warning is logged, rather than trying to work with a returned length
of -1. This needs to change to quitting later. (See the usage sketch after
this list.)

* Also avoid wrapping the message in a vector, as only a single message
at a time is tagged with respect to the chat handshake template.
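
A quick usage sketch of the caller-visible behaviour after these two changes
(hedged: it assumes the patched helper header is included, and the chatml
expansion shown is what upstream's chatml template normally produces; exact
whitespace may differ):

    // known template id: the single message is tagged directly, no vector needed
    std::string tagged = llama_chat_apply_template_simple("chatml", "user", "hello", true);
    // tagged ~= "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n"

    // unknown template id: a WARN is logged and "" is returned,
    // instead of using the buffer with slen == -1
    std::string bad = llama_chat_apply_template_simple("no-such-id", "user", "hello", true);
    if (bad.empty()) {
        // caller can bail out; per the note above this should become a hard quit later
    }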

ReversePrompt

Add support for llama2
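
In code terms, the lookup now behaves like this (return values taken from the
diff below; the empty vector is what main checks to reject unsupported ids):

    std::string id = "llama2";
    std::vector<std::string> rends = llama_chat_reverse_prompt(id);
    // "chatml"      -> { "<|im_start|>user\n" }
    // "llama2"      -> { "</s>" }        (added by this commit)
    // "llama3"      -> { "<|eot_id|>" }
    // any other id  -> { } ; main logs ERRR and exits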
Author: HanishKVC, 2024-04-20 20:08:00 +05:30
Parent: 0a8797b28e
Commit: aac2ee6e9d
2 changed files with 16 additions and 6 deletions

File 1 of 2:

@@ -7,18 +7,24 @@
 #include "log.h"

 inline std::string llama_chat_apply_template_simple(
-        const std::string & tmpl,
+        const std::string &tmpl,
         const std::string &role,
         const std::string &content,
         bool add_ass) {
     llama_chat_message msg = { role.c_str(), content.c_str() };
-    std::vector<llama_chat_message> msgs{ msg };
+    //std::vector<llama_chat_message> msgs{ msg };
     std::vector<char> buf(content.size() * 2);

-    int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), msgs.data(), msgs.size(), add_ass, buf.data(), buf.size());
+    int32_t slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+    LOG_TEELN("DBUG:%s:AA:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
+    if (slen == -1) {
+        LOG_TEELN("WARN:%s:Unknown template [%s] encounted", __func__, tmpl.c_str());
+        return "";
+    }
     if ((size_t) slen > buf.size()) {
         buf.resize(slen);
-        slen = llama_chat_apply_template(nullptr, tmpl.c_str(), msgs.data(), msgs.size(), add_ass, buf.data(), buf.size());
+        slen = llama_chat_apply_template(nullptr, tmpl.c_str(), &msg, 1, add_ass, buf.data(), buf.size());
+        LOG_TEELN("DBUG:%s:BB:%s:LengthNeeded:%d:BufSizeWas:%zu", __func__, role.c_str(), slen, buf.size());
     }

     const std::string tagged_msg(buf.data(), slen);
@@ -28,11 +34,13 @@ inline std::string llama_chat_apply_template_simple(

 // return what should be the reverse prompt for the given template id
 // ie possible end text tag(s) of specified model type's chat query response
-std::vector<std::string> llama_chat_reverse_prompt(std::string &template_id) {
+inline std::vector<std::string> llama_chat_reverse_prompt(std::string &template_id) {
     std::vector<std::string> rends;
     if (template_id == "chatml") {
         rends.push_back("<|im_start|>user\n");
+    } else if (template_id == "llama2") {
+        rends.push_back("</s>");
     } else if (template_id == "llama3") {
         rends.push_back("<|eot_id|>");
     }
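
For context, the resize branch above works because llama_chat_apply_template
reports the length it needed even when the buffer is too small; the helper
guesses content.size() * 2, then retries once with the exact size. The same
two-call pattern in isolation (a sketch assuming the llama.h signature of
this period, where a nullptr model makes the function use the template
string directly):

    llama_chat_message msg = { "user", "hello there" };
    std::vector<char> buf(8);                              // deliberately undersized
    int32_t n = llama_chat_apply_template(nullptr, "chatml", &msg, 1,
                                          /*add_ass=*/true, buf.data(), buf.size());
    if (n > (int32_t) buf.size()) {                        // n is the size actually needed
        buf.resize(n);
        n = llama_chat_apply_template(nullptr, "chatml", &msg, 1,
                                      true, buf.data(), buf.size());
    }
    std::string out(buf.data(), n);                        // fully tagged message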

File 2 of 2:

@@ -258,7 +258,9 @@ int main(int argc, char ** argv) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
         if (params.chaton) {
+            LOG_TEELN("DBUG:%s:AA:%s", __func__, params.prompt.c_str());
             params.prompt = llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, false);
+            LOG_TEELN("DBUG:%s:BB:%s", __func__, params.prompt.c_str());
         }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
@@ -372,7 +374,7 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         std::vector<std::string> resp_ends = llama_chat_reverse_prompt(params.chaton_template_id);
         if (resp_ends.size() == 0) {
-            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatType:%s", __func__, params.chaton_template_id.c_str());
+            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatTemplateType:%s", __func__, params.chaton_template_id.c_str());
             exit(1);
         }
         for (size_t i = 0; i < resp_ends.size(); i++)
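
The captured hunk cuts off at the loop header, so the body is not shown here.
Presumably each returned end tag is registered as a reverse prompt, so that
generation pauses at the model's end-of-turn marker. A hedged guess at what
follows (params.antiprompt is the existing reverse-prompt list in gpt_params,
but the push_back body itself is an assumption, not part of the captured diff):

    for (size_t i = 0; i < resp_ends.size(); i++) {
        // assumed body: each chat end tag becomes an antiprompt
        params.antiprompt.push_back(resp_ends[i]);
    }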