ChatON+:Multi4Single: applyGlobalIfAny flag wrt templating api

Given that the multi-chat templating logic itself is now used to
apply chat templating/tagging to a single chat message, give the
caller the flexibility to decide whether global tags, if any,
should be applied wrt the core tagging logic.

examples/main in turn updated to not apply global tags, if any, wrt
the system message. Also the user messages already don't apply
global tags, if any, as it is currently implemented to build on the
existing in-prefix/suffix and antiprompt flow.
This commit is contained in:
HanishKVC 2024-05-14 00:58:16 +05:30
parent 8165bd4035
commit 3fcaf19967
2 changed files with 16 additions and 9 deletions

View file

@ -531,6 +531,7 @@ inline bool chaton_tmpl_apply_ex(
const std::string &tmpl, const std::string &tmpl,
const std::vector<const llama_chat_message *> &msgs, const std::vector<const llama_chat_message *> &msgs,
bool alertAssistantAtEnd, bool alertAssistantAtEnd,
bool applyGlobalIfAny,
std::string &tagged, std::string &tagged,
std::string &types, std::string &types,
std::vector<int32_t> &lens std::vector<int32_t> &lens
@ -539,8 +540,10 @@ inline bool chaton_tmpl_apply_ex(
return false; return false;
} }
ChatParts cp = {}; ChatParts cp = {};
if (applyGlobalIfAny) {
std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN});
cp.add_part(ChatParts::S, globalBegin); cp.add_part(ChatParts::S, globalBegin);
}
int cntSystem = 0; int cntSystem = 0;
int cntUser = 0; int cntUser = 0;
int cntOthers = 0; int cntOthers = 0;
@ -590,8 +593,10 @@ inline bool chaton_tmpl_apply_ex(
auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX});
cp.add_part(ChatParts::S, assistantBeginPrefix); cp.add_part(ChatParts::S, assistantBeginPrefix);
} }
if (applyGlobalIfAny) {
auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END});
cp.add_part(ChatParts::S, globalEnd); cp.add_part(ChatParts::S, globalEnd);
}
cp.dump(); cp.dump();
tagged = cp.str(); tagged = cp.str();
LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str()); LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str());
@ -608,11 +613,12 @@ inline int32_t chaton_tmpl_apply(
const std::string &tmpl, const std::string &tmpl,
const std::vector<const llama_chat_message *> &msgs, const std::vector<const llama_chat_message *> &msgs,
bool alertAssistantAtEnd, bool alertAssistantAtEnd,
bool applyGlobalIfAny,
std::string &tagged std::string &tagged
) { ) {
std::string types; std::string types;
std::vector<int32_t> lens; std::vector<int32_t> lens;
if (!chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, tagged, types, lens)) { if (!chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens)) {
return -1; return -1;
} }
return tagged.size(); return tagged.size();
@ -629,12 +635,13 @@ inline size_t chaton_tmpl_apply_single(
const std::string &role, const std::string &role,
const std::string &content, const std::string &content,
bool alertAssistantAtEnd, bool alertAssistantAtEnd,
bool applyGlobalIfAny,
std::string &tagged std::string &tagged
) { ) {
std::string types; std::string types;
std::vector<int32_t> lens; std::vector<int32_t> lens;
llama_chat_message cm {role.c_str(), content.c_str()}; llama_chat_message cm {role.c_str(), content.c_str()};
if (!chaton_tmpl_apply_ex(tmpl, {&cm}, alertAssistantAtEnd, tagged, types, lens)) { if (!chaton_tmpl_apply_ex(tmpl, {&cm}, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens)) {
return -1; return -1;
} }
return tagged.size(); return tagged.size();
@ -669,7 +676,7 @@ inline int32_t chaton_tmpl_apply_capi(
vMsgs.push_back(&msgs[i]); vMsgs.push_back(&msgs[i]);
} }
std::string taggedMsgs; std::string taggedMsgs;
int32_t taggedLength = chaton_tmpl_apply(tmpl, vMsgs, alertAssistantAtEnd, taggedMsgs); int32_t taggedLength = chaton_tmpl_apply(tmpl, vMsgs, alertAssistantAtEnd, true, taggedMsgs);
if (taggedLength < 0) { if (taggedLength < 0) {
return taggedLength; return taggedLength;
} }
@ -714,7 +721,7 @@ inline int32_t chaton_tmpl_apply_ex_capi(
std::string taggedMsgs; std::string taggedMsgs;
std::string types; std::string types;
std::vector<int32_t> lens; std::vector<int32_t> lens;
if (!chaton_tmpl_apply_ex(tmpl, vMsgs, alertAssistantAtEnd, taggedMsgs, types, lens)) { if (!chaton_tmpl_apply_ex(tmpl, vMsgs, alertAssistantAtEnd, true, taggedMsgs, types, lens)) {
return -1; return -1;
} }
int32_t taggedLength = taggedMsgs.size(); int32_t taggedLength = taggedMsgs.size();

View file

@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
} }
if (params.chaton) { if (params.chaton) {
chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt, params.prompt); chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt, false, false, params.prompt);
} }
embd_inp = ::llama_tokenize(ctx, params.prompt, true, true); embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
} else { } else {