From 0cfe99076dc930377021b2a40bc69271ac0d9740 Mon Sep 17 00:00:00 2001 From: HanishKVC Date: Mon, 13 May 2024 16:51:07 +0530 Subject: [PATCH] ChatON:ChatTemplates: TmplExists, TmplGetKey, TmplRoleGetKeys ChatTemplate directly supports these now, as well as the existing global instance based corresponding helpers depend on same. --- common/chaton.hpp | 93 ++++++++++++++++++++++++++---------------- examples/main/main.cpp | 8 ++-- 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/common/chaton.hpp b/common/chaton.hpp index 7a633c13a..13ec3a853 100644 --- a/common/chaton.hpp +++ b/common/chaton.hpp @@ -305,6 +305,43 @@ public: ChatTemplates(GroupKVMapMapVariant defaultMap) : GroupKV(defaultMap) {} + /** + * Check if the specified chat-template exists or not. + * NOTE: This doesnt cross check, if the template inturn contains all the required fields or not. + */ + bool tmpl_exists(const std::string &tmpl) { + if (!group_exists(tmpl)) { + LOG_TEELN("WARN:CT:%s: tmpl[%s] not found...", __func__, tmpl.c_str()); + return false; + } + return true; + } + + /** + * For the specified chat-template, get the value associated with the specified key/field. + */ + template + SupportedDataType tmpl_getkey(const std::string &tmpl, const std::string &key, const SupportedDataType &defaultValue) { + return get_value(tmpl, {key}, defaultValue, "CTTmplGetKey"); + } + + /** + * For the specified chat-template and the role within, cumulate the values of the specified keys/fields + * and return the same. + */ + std::string tmpl_role_getkeys(const std::string &tmpl, const std::string &role, const std::vector &keys) { + std::string got = ""; + std::string sKeys = ""; + for(auto key: keys) { + got += get_value(tmpl, {role, key}, "", "CTTmplRoleGetKeys"); + sKeys += "+"; + sKeys += key; + } + LDBUG_LN("DBUG:CT:%s:%s:%s:%s:%s", __func__, tmpl.c_str(), role.c_str(), sKeys.c_str(), got.c_str()); + return got; + } + + }; #include "chaton_meta.hpp" @@ -394,35 +431,19 @@ inline bool chaton_meta_load(const std::string &fname) { inline bool chaton_tmpl_exists(const std::string &tmpl) { - if (!gCT.group_exists(tmpl)) { - LOG_TEELN("WARN:%s: tmpl[%s] not found...", __func__, tmpl.c_str()); - return false; - } - return true; + return gCT.tmpl_exists(tmpl); } -inline std::string chaton_tmpl_role_kv(const std::string &tmpl, const std::string &role, const std::vector &keys) { - std::string got = ""; - std::string sKeys = ""; - for(auto key: keys) { - got += gCT.get_value(tmpl, {role, key}, ""); - sKeys += "+"; - sKeys += key; - } - LOGLN("DBUG:%s:%s:%s:%s:%s", __func__, tmpl.c_str(), role.c_str(), sKeys.c_str(), got.c_str()); - return got; +inline std::string chaton_tmpl_role_getkeys(const std::string &tmpl, const std::string &role, const std::vector &keys) { + return gCT.tmpl_role_getkeys(tmpl, role, keys); } -inline std::string chaton_tmpl_kv(const std::string &tmpl, const std::string &key) { - std::string got = gCT.get_value(tmpl, {key}, ""); - LOGLN("DBUG:%s:%s:%s:%s", __func__, tmpl.c_str(), key.c_str(), got.c_str()); - return got; +inline std::string chaton_tmpl_getkey_str(const std::string &tmpl, const std::string &key) { + return gCT.tmpl_getkey(tmpl, {key}, ""); } -inline bool chaton_tmpl_kv_bool(const std::string &tmpl, const std::string &key) { - bool got = gCT.get_value(tmpl, {key}, false); - LOGLN("DBUG:%s:%s:%s:%d", __func__, tmpl.c_str(), key.c_str(), got); - return got; +inline bool chaton_tmpl_getkey_bool(const std::string &tmpl, const std::string &key) { + return gCT.tmpl_getkey(tmpl, {key}, false); } @@ -446,8 +467,8 @@ inline bool chaton_tmpl_apply_single_ex( return false; } ChatParts cp = {}; - std::string beginPrefix = chaton_tmpl_role_kv(tmpl, role, {K_BEGIN, K_PREFIX}); - std::string suffixEnd = chaton_tmpl_role_kv(tmpl, role, {K_SUFFIX, K_END}); + std::string beginPrefix = chaton_tmpl_role_getkeys(tmpl, role, {K_BEGIN, K_PREFIX}); + std::string suffixEnd = chaton_tmpl_role_getkeys(tmpl, role, {K_SUFFIX, K_END}); cp.add_part(ChatParts::S, beginPrefix); cp.add_part(ChatParts::N, content); cp.add_part(ChatParts::S, suffixEnd); @@ -531,7 +552,7 @@ inline bool chaton_tmpl_apply_ex( return false; } ChatParts cp = {}; - std::string globalBegin = chaton_tmpl_role_kv(tmpl, K_GLOBAL, {K_BEGIN}); + std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); cp.add_part(ChatParts::S, globalBegin); int cntSystem = 0; int cntUser = 0; @@ -539,10 +560,10 @@ inline bool chaton_tmpl_apply_ex( for(const auto msg: msgs) { auto role = msg->role; auto content = msg->content; - std::string begin = chaton_tmpl_role_kv(tmpl, role, {K_BEGIN}); - auto prefix = chaton_tmpl_role_kv(tmpl, role, {K_PREFIX}); - auto suffix = chaton_tmpl_role_kv(tmpl, role, {K_SUFFIX}); - auto end = chaton_tmpl_role_kv(tmpl, role, {K_END}); + std::string begin = chaton_tmpl_role_getkeys(tmpl, role, {K_BEGIN}); + auto prefix = chaton_tmpl_role_getkeys(tmpl, role, {K_PREFIX}); + auto suffix = chaton_tmpl_role_getkeys(tmpl, role, {K_SUFFIX}); + auto end = chaton_tmpl_role_getkeys(tmpl, role, {K_END}); if (role == K_SYSTEM) { cntSystem += 1; cp.add_part(ChatParts::S, begin); @@ -550,10 +571,10 @@ inline bool chaton_tmpl_apply_ex( } else if (role == K_USER) { cntUser += 1; if ((cntSystem == 1) && (cntUser == 1)) { - if (chaton_tmpl_kv_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN)) { + if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN)) { cp.add_part(ChatParts::S, begin); } - if (chaton_tmpl_kv_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX)) { + if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX)) { cp.add_part(ChatParts::S, prefix); } } else { @@ -567,10 +588,10 @@ inline bool chaton_tmpl_apply_ex( } cp.add_part(ChatParts::N, content); if (role == K_SYSTEM) { - if (chaton_tmpl_kv_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX)) { + if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX)) { cp.add_part(ChatParts::S, suffix); } - if (chaton_tmpl_kv_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END)) { + if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END)) { cp.add_part(ChatParts::S, end); } } else { @@ -579,10 +600,10 @@ inline bool chaton_tmpl_apply_ex( } } if (alertAssistantAtEnd) { - auto assistantBeginPrefix = chaton_tmpl_role_kv(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); + auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); cp.add_part(ChatParts::S, assistantBeginPrefix); } - auto globalEnd = chaton_tmpl_role_kv(tmpl, K_GLOBAL, {K_END}); + auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); cp.add_part(ChatParts::S, globalEnd); cp.dump(); tagged = cp.str(); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a600d16b5..1a982d4ce 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -380,13 +380,13 @@ int main(int argc, char ** argv) { } // chaton mode - const auto chaton_assitant_prefix = ::llama_tokenize(ctx, chaton_tmpl_role_kv(params.chaton_template_id, K_ASSISTANT, {K_BEGIN, K_PREFIX}), false, true); + const auto chaton_assitant_prefix = ::llama_tokenize(ctx, chaton_tmpl_role_getkeys(params.chaton_template_id, K_ASSISTANT, {K_BEGIN, K_PREFIX}), false, true); if (params.chaton) { params.interactive = true; // may remove later, by requiring user to explicitly request interactive mode params.interactive_first = true; - params.input_prefix = chaton_tmpl_role_kv(params.chaton_template_id, K_USER, {K_BEGIN, K_PREFIX}); - params.input_suffix = chaton_tmpl_role_kv(params.chaton_template_id, K_USER, {K_SUFFIX, K_END}); - params.antiprompt.emplace_back(chaton_tmpl_kv(params.chaton_template_id, K_REVERSE_PROMPT)); + params.input_prefix = chaton_tmpl_role_getkeys(params.chaton_template_id, K_USER, {K_BEGIN, K_PREFIX}); + params.input_suffix = chaton_tmpl_role_getkeys(params.chaton_template_id, K_USER, {K_SUFFIX, K_END}); + params.antiprompt.emplace_back(chaton_tmpl_getkey_str(params.chaton_template_id, K_REVERSE_PROMPT)); } // enable interactive mode if interactive start is specified