diff --git a/common/chaton.hpp b/common/chaton.hpp index 94c8b4c8c..3e69c5fda 100644 --- a/common/chaton.hpp +++ b/common/chaton.hpp @@ -531,6 +531,7 @@ inline bool chaton_tmpl_apply_ex( const std::string &tmpl, const std::vector &msgs, bool alertAssistantAtEnd, + bool applyGlobalIfAny, std::string &tagged, std::string &types, std::vector &lens @@ -539,8 +540,10 @@ inline bool chaton_tmpl_apply_ex( return false; } ChatParts cp = {}; - std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); - cp.add_part(ChatParts::S, globalBegin); + if (applyGlobalIfAny) { + std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN}); + cp.add_part(ChatParts::S, globalBegin); + } int cntSystem = 0; int cntUser = 0; int cntOthers = 0; @@ -590,8 +593,10 @@ inline bool chaton_tmpl_apply_ex( auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX}); cp.add_part(ChatParts::S, assistantBeginPrefix); } - auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); - cp.add_part(ChatParts::S, globalEnd); + if (applyGlobalIfAny) { + auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END}); + cp.add_part(ChatParts::S, globalEnd); + } cp.dump(); tagged = cp.str(); LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str()); @@ -608,11 +613,12 @@ inline int32_t chaton_tmpl_apply( const std::string &tmpl, const std::vector &msgs, bool alertAssistantAtEnd, + bool applyGlobalIfAny, std::string &tagged ) { std::string types; std::vector lens; - if (!chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, tagged, types, lens)) { + if (!chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens)) { return -1; } return tagged.size(); @@ -629,12 +635,13 @@ inline size_t chaton_tmpl_apply_single( const std::string &role, const std::string &content, bool alertAssistantAtEnd, + bool applyGlobalIfAny, std::string &tagged ) { std::string types; std::vector lens; llama_chat_message cm {role.c_str(), content.c_str()}; - if (!chaton_tmpl_apply_ex(tmpl, {&cm}, alertAssistantAtEnd, tagged, types, lens)) { + if (!chaton_tmpl_apply_ex(tmpl, {&cm}, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens)) { return -1; } return tagged.size(); @@ -669,7 +676,7 @@ inline int32_t chaton_tmpl_apply_capi( vMsgs.push_back(&msgs[i]); } std::string taggedMsgs; - int32_t taggedLength = chaton_tmpl_apply(tmpl, vMsgs, alertAssistantAtEnd, taggedMsgs); + int32_t taggedLength = chaton_tmpl_apply(tmpl, vMsgs, alertAssistantAtEnd, true, taggedMsgs); if (taggedLength < 0) { return taggedLength; } @@ -714,7 +721,7 @@ inline int32_t chaton_tmpl_apply_ex_capi( std::string taggedMsgs; std::string types; std::vector lens; - if (!chaton_tmpl_apply_ex(tmpl, vMsgs, alertAssistantAtEnd, taggedMsgs, types, lens)) { + if (!chaton_tmpl_apply_ex(tmpl, vMsgs, alertAssistantAtEnd, true, taggedMsgs, types, lens)) { return -1; } int32_t taggedLength = taggedMsgs.size(); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1a982d4ce..f86592fc7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; } if (params.chaton) { - chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt, params.prompt); + chaton_tmpl_apply_single(params.chaton_template_id, K_SYSTEM, params.prompt, false, false, params.prompt); } embd_inp = ::llama_tokenize(ctx, params.prompt, true, true); } else {