From 4dfd10a40d4db55cc3a25572d36269a32f408d84 Mon Sep 17 00:00:00 2001 From: HanishKVC Date: Tue, 14 May 2024 01:49:38 +0530 Subject: [PATCH] ChatON: Move core templating/tagging code into ChatTemplates class However still retain the wrappers, which work with a predefined global instance of ChatTemplates. --- common/chaton.hpp | 191 ++++++++++++++++++++++++++-------------------- 1 file changed, 107 insertions(+), 84 deletions(-) diff --git a/common/chaton.hpp b/common/chaton.hpp index 96d25c045..ba377f2b9 100644 --- a/common/chaton.hpp +++ b/common/chaton.hpp @@ -409,6 +409,108 @@ public: return got; } + /** + * Given the template standard and a bunch of messages including their roles, this returns + * tagged messages, subPartsTypes string and subPartsLens vector. The returned subParts + * types string and lens vector help identify the parts of the tagged msgs string, + * which relate to passed msgs and added tags. + * + * * a string containing the tagged messages + * * global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end + * * a string where the chars contain info about + * type of sub-strings/parts that make up the tagged messages string. + * * a vector of ints, which give the length of each part in the tagged messages string. + * + * If a combination of system-user messages is passed, then tags between the 1st system and + * the 1st user message, is based on the flags set wrt the corresponding template standard. + * If you dont want this behaviour, pass non 0 values wrt the optional cntSystemMsgCnt and + * cntUserMsgCnt arguments. + */ + bool chaton_tmpl_apply_ex( + const std::string &tmpl, + const std::vector &msgs, + bool alertAssistantAtEnd, + bool applyGlobalIfAny, + std::string &tagged, + std::string &types, + std::vector &lens, + int curSystemMsgCnt = 0, + int curUserMsgCnt = 0 + ) { + if (!tmpl_exists(tmpl)) { + return false; + } + ChatParts cp = {}; + if (applyGlobalIfAny) { + std::string globalBegin = tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); + cp.add_part(ChatParts::S, globalBegin); + } + int cntSystem = curSystemMsgCnt; + int cntUser = curUserMsgCnt; + int cntOthers = 0; + for(const auto msg: msgs) { + auto role = msg->role; + auto content = msg->content; + std::string begin = tmpl_role_getkeys(tmpl, role, {K_BEGIN}); + auto prefix = tmpl_role_getkeys(tmpl, role, {K_PREFIX}); + auto suffix = tmpl_role_getkeys(tmpl, role, {K_SUFFIX}); + auto end = tmpl_role_getkeys(tmpl, role, {K_END}); + if (role == K_SYSTEM) { + cntSystem += 1; + cp.add_part(ChatParts::S, begin); + cp.add_part(ChatParts::S, prefix); + } else if (role == K_USER) { + cntUser += 1; + if ((cntSystem == 1) && (cntUser == 1)) { + if (tmpl_getkey(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN, true)) { + cp.add_part(ChatParts::S, begin); + } + if (tmpl_getkey(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX, true)) { + cp.add_part(ChatParts::S, prefix); + } + } else { + cp.add_part(ChatParts::S, begin); + cp.add_part(ChatParts::S, prefix); + } + } else { + cntOthers += 1; + cp.add_part(ChatParts::S, begin); + cp.add_part(ChatParts::S, prefix); + } + cp.add_part(ChatParts::N, content); + if (role == K_SYSTEM) { + if (cntSystem == 1) { + if (tmpl_getkey(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX, true)) { + cp.add_part(ChatParts::S, suffix); + } + if (tmpl_getkey(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END, true)) { + cp.add_part(ChatParts::S, end); + } + } else { + cp.add_part(ChatParts::S, suffix); + cp.add_part(ChatParts::S, end); + } + } else { + cp.add_part(ChatParts::S, suffix); + cp.add_part(ChatParts::S, end); + } + } + if (alertAssistantAtEnd) { + auto assistantBeginPrefix = tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); + cp.add_part(ChatParts::S, assistantBeginPrefix); + } + if (applyGlobalIfAny) { + auto globalEnd = tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); + cp.add_part(ChatParts::S, globalEnd); + } + cp.dump(); + tagged = cp.str(); + LOGLN("DBUG:CT:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str()); + LOGLN("DBUG:CT:%s:%s:CntSys[%d]:CntUsr[%d]:CntOthers[%d]", __func__, tmpl.c_str(), cntSystem, cntUser, cntOthers); + types = cp.get_partstypes(); + lens = cp.get_partslens(); + return true; + } }; @@ -516,17 +618,10 @@ inline bool chaton_tmpl_getkey_bool(const std::string &tmpl, const std::string & // Given the template standard and a bunch of messages including their roles, this returns -// tagged messages, types string and lens vector. Returned types string and lens vector help -// identify the parts of the tagged msgs string, which relate to passed msgs and added tags. +// the tagged messages as a string. +// global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end // -// * a string containing the tagged messages -// * global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end -// * a string where the chars contain info about -// type of sub-strings/parts that make up the tagged messages string. -// * a vector of ints, which give the length of each part in the tagged messages string. -// -// if a combination of system-user messages is passed, then tags between the system -// and the 1st user message, is based on the flags set wrt the corresponding template standard. +// Additionally also return info about the parts that make up the tagged message. inline bool chaton_tmpl_apply_ex( const std::string &tmpl, const std::vector &msgs, @@ -538,79 +633,7 @@ inline bool chaton_tmpl_apply_ex( int curSystemMsgCnt = 0, int curUserMsgCnt = 0 ) { - if (!chaton_tmpl_exists(tmpl)) { - return false; - } - ChatParts cp = {}; - if (applyGlobalIfAny) { - std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); - cp.add_part(ChatParts::S, globalBegin); - } - int cntSystem = curSystemMsgCnt; - int cntUser = curUserMsgCnt; - int cntOthers = 0; - for(const auto msg: msgs) { - auto role = msg->role; - auto content = msg->content; - std::string begin = chaton_tmpl_role_getkeys(tmpl, role, {K_BEGIN}); - auto prefix = chaton_tmpl_role_getkeys(tmpl, role, {K_PREFIX}); - auto suffix = chaton_tmpl_role_getkeys(tmpl, role, {K_SUFFIX}); - auto end = chaton_tmpl_role_getkeys(tmpl, role, {K_END}); - if (role == K_SYSTEM) { - cntSystem += 1; - cp.add_part(ChatParts::S, begin); - cp.add_part(ChatParts::S, prefix); - } else if (role == K_USER) { - cntUser += 1; - if ((cntSystem == 1) && (cntUser == 1)) { - if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN)) { - cp.add_part(ChatParts::S, begin); - } - if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX)) { - cp.add_part(ChatParts::S, prefix); - } - } else { - cp.add_part(ChatParts::S, begin); - cp.add_part(ChatParts::S, prefix); - } - } else { - cntOthers += 1; - cp.add_part(ChatParts::S, begin); - cp.add_part(ChatParts::S, prefix); - } - cp.add_part(ChatParts::N, content); - if (role == K_SYSTEM) { - if (cntSystem == 1) { - if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX)) { - cp.add_part(ChatParts::S, suffix); - } - if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END)) { - cp.add_part(ChatParts::S, end); - } - } else { - cp.add_part(ChatParts::S, suffix); - cp.add_part(ChatParts::S, end); - } - } else { - cp.add_part(ChatParts::S, suffix); - cp.add_part(ChatParts::S, end); - } - } - if (alertAssistantAtEnd) { - auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); - cp.add_part(ChatParts::S, assistantBeginPrefix); - } - if (applyGlobalIfAny) { - auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); - cp.add_part(ChatParts::S, globalEnd); - } - cp.dump(); - tagged = cp.str(); - LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str()); - LOGLN("DBUG:%s:%s:CntSys[%d]:CntUsr[%d]:CntOthers[%d]", __func__, tmpl.c_str(), cntSystem, cntUser, cntOthers); - types = cp.get_partstypes(); - lens = cp.get_partslens(); - return true; + return gCT.chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens, curSystemMsgCnt, curUserMsgCnt); } // Given the template standard and a bunch of messages including their roles, this returns @@ -751,7 +774,7 @@ inline int32_t chaton_tmpl_apply_ex_capi( return taggedLength; } -// Copied from common.cpp +// Copied from common.cpp, updated wrt model and logging flow. inline std::vector chaton_llama_tokenize( const struct llama_model * model, const std::string & text,