diff --git a/common/chaton.hpp b/common/chaton.hpp
index 7ea7b1932..d74dc08de 100644
--- a/common/chaton.hpp
+++ b/common/chaton.hpp
@@ -165,11 +165,29 @@ public:
 };
 
 inline bool chaton_meta_load(std::string &fname) {
+    if (conMeta != nullptr) {
+        LOGXLN("WARN:%s:ChatOn Meta: overwriting???", __func__);
+    }
     std::ifstream f(fname);
     conMeta = json::parse(f);
     return true;
 }
 
+// Check whether the loaded ChatOn meta json has an entry for the given template id.
+// Returns false (and logs) if the meta data is not loaded yet or the template is missing.
+inline bool chaton_tmpl_exists(const std::string &tmpl) {
+    if (conMeta == nullptr) {
+        LOG_TEELN("ERRR:%s:ChatOnMeta: Not loaded yet...", __func__);
+        return false;
+    }
+    // nlohmann json's non-const operator[] inserts a null entry for a missing key
+    // instead of throwing, so look the key up without mutating conMeta.
+    if (!conMeta.contains(tmpl)) {
+        LOG_TEELN("WARN:%s:ChatOnMeta: tmpl[%s] not found...", __func__, tmpl.c_str());
+        return false;
+    }
+    return true;
+}
 
 inline std::string chaton_tmpl_role_kv(const std::string &tmpl, const std::string &role, const std::vector &keys) {
     std::string got = "";
@@ -230,12 +248,63 @@ inline bool chaton_tmpl_apply_single_ex(
 }
 
 // Return user-prefix + msg + user-suffix, types string and lens vector wrt the parts that make up the returned string
-inline std::string chaton_tmpl_apply_single(const std::string &tmpl, const std::string &role, const std::string &content) {
-    std::string tagged;
+// The result overwrites tagged; the return value is the length of tagged.
+// Returns 0, leaving tagged untouched, if the template is unknown or the meta data is not loaded.
+// content and tagged may refer to the same string.
+inline size_t chaton_tmpl_apply_single(
+        const std::string &tmpl,
+        const std::string &role,
+        const std::string &content,
+        std::string &tagged
+        ) {
+    if (!chaton_tmpl_exists(tmpl)) {
+        // size_t is unsigned, so a -1 would reach callers as SIZE_MAX; use 0 to flag failure
+        return 0;
+    }
+    // Build into a scratch string so a caller passing the same object as content and tagged stays safe
+    std::string result;
     std::string types;
     std::vector lens;
-    chaton_tmpl_apply_single_ex(tmpl, role, content, tagged, types, lens);
-    return tagged;
+    chaton_tmpl_apply_single_ex(tmpl, role, content, result, types, lens);
+    tagged.swap(result);
+    return tagged.size();
+}
+
+/**
+ * Apply chat-handshake-template for the specified template standard and role.
+ * If the passed char array is smaller than that required for the tagged message,
+ * * part of the tagged message which fits within dest buffer is copied
+ * * the returned value, indicates the size of the tagged message
+ * NOTE:
+ * * the passed char array should be able to fit the tagged message+0|null char.
+ * * if the return value from this function is larger than or equal to destLength,
+ *   then you will have to increase the size of the dest buffer, and call this
+ *   function a second time, to ensure that one gets the full tagged message.
+ * * a return value of 0 indicates failure (null argument, unknown template or meta data
+ *   not loaded) or an empty tagged message; dest is left untouched in that case.
+ */
+inline size_t chat_tmpl_apply_single_capi(
+        const char *tmpl,
+        const char *role,
+        const char *content,
+        char *dest,
+        const size_t destLength
+        ) {
+    // Constructing a std::string from a null pointer is undefined behaviour
+    if ((tmpl == nullptr) || (role == nullptr) || (content == nullptr)) {
+        return 0;
+    }
+    std::string tagged;
+    const size_t taggedLength = chaton_tmpl_apply_single(tmpl, role, content, tagged);
+    if (taggedLength == 0) {
+        return 0;
+    }
+    if (dest && (destLength > 0)) {
+        // Truncating copy with guaranteed null termination; avoids the non-portable strlcpy
+        const size_t copyLength = (taggedLength < destLength) ? taggedLength : (destLength - 1);
+        tagged.copy(dest, copyLength);
+        dest[copyLength] = '\0';
+    }
+    return taggedLength;
 }
 
 // global-begin + [[role-begin] + [role-prefix] + msg + role-suffix] + global-end
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 074db4598..2a41c8e1e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
         if (params.chaton) {
-            params.prompt = chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt);
+            chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt, params.prompt);
         }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {