ChatON: Move core templating/tagging code into ChatTemplates class
However still retain the wrappers, which work with a predefined global instance of ChatTemplates.
This commit is contained in:
parent
600653dae2
commit
4dfd10a40d
1 changed files with 107 additions and 84 deletions
|
@ -409,6 +409,108 @@ public:
|
||||||
return got;
|
return got;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given the template standard and a bunch of messages including their roles, this returns
|
||||||
|
* tagged messages, subPartsTypes string and subPartsLens vector. The returned subParts
|
||||||
|
* types string and lens vector help identify the parts of the tagged msgs string,
|
||||||
|
* which relate to passed msgs and added tags.
|
||||||
|
*
|
||||||
|
* * a string containing the tagged messages
|
||||||
|
* * global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end
|
||||||
|
* * a string where the chars contain info about
|
||||||
|
* type of sub-strings/parts that make up the tagged messages string.
|
||||||
|
* * a vector of ints, which give the length of each part in the tagged messages string.
|
||||||
|
*
|
||||||
|
* If a combination of system-user messages is passed, then tags between the 1st system and
|
||||||
|
* the 1st user message, is based on the flags set wrt the corresponding template standard.
|
||||||
|
* If you dont want this behaviour, pass non 0 values wrt the optional cntSystemMsgCnt and
|
||||||
|
* cntUserMsgCnt arguments.
|
||||||
|
*/
|
||||||
|
bool chaton_tmpl_apply_ex(
|
||||||
|
const std::string &tmpl,
|
||||||
|
const std::vector<const llama_chat_message *> &msgs,
|
||||||
|
bool alertAssistantAtEnd,
|
||||||
|
bool applyGlobalIfAny,
|
||||||
|
std::string &tagged,
|
||||||
|
std::string &types,
|
||||||
|
std::vector<int32_t> &lens,
|
||||||
|
int curSystemMsgCnt = 0,
|
||||||
|
int curUserMsgCnt = 0
|
||||||
|
) {
|
||||||
|
if (!tmpl_exists(tmpl)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ChatParts cp = {};
|
||||||
|
if (applyGlobalIfAny) {
|
||||||
|
std::string globalBegin = tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN});
|
||||||
|
cp.add_part(ChatParts::S, globalBegin);
|
||||||
|
}
|
||||||
|
int cntSystem = curSystemMsgCnt;
|
||||||
|
int cntUser = curUserMsgCnt;
|
||||||
|
int cntOthers = 0;
|
||||||
|
for(const auto msg: msgs) {
|
||||||
|
auto role = msg->role;
|
||||||
|
auto content = msg->content;
|
||||||
|
std::string begin = tmpl_role_getkeys(tmpl, role, {K_BEGIN});
|
||||||
|
auto prefix = tmpl_role_getkeys(tmpl, role, {K_PREFIX});
|
||||||
|
auto suffix = tmpl_role_getkeys(tmpl, role, {K_SUFFIX});
|
||||||
|
auto end = tmpl_role_getkeys(tmpl, role, {K_END});
|
||||||
|
if (role == K_SYSTEM) {
|
||||||
|
cntSystem += 1;
|
||||||
|
cp.add_part(ChatParts::S, begin);
|
||||||
|
cp.add_part(ChatParts::S, prefix);
|
||||||
|
} else if (role == K_USER) {
|
||||||
|
cntUser += 1;
|
||||||
|
if ((cntSystem == 1) && (cntUser == 1)) {
|
||||||
|
if (tmpl_getkey(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN, true)) {
|
||||||
|
cp.add_part(ChatParts::S, begin);
|
||||||
|
}
|
||||||
|
if (tmpl_getkey(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX, true)) {
|
||||||
|
cp.add_part(ChatParts::S, prefix);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cp.add_part(ChatParts::S, begin);
|
||||||
|
cp.add_part(ChatParts::S, prefix);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cntOthers += 1;
|
||||||
|
cp.add_part(ChatParts::S, begin);
|
||||||
|
cp.add_part(ChatParts::S, prefix);
|
||||||
|
}
|
||||||
|
cp.add_part(ChatParts::N, content);
|
||||||
|
if (role == K_SYSTEM) {
|
||||||
|
if (cntSystem == 1) {
|
||||||
|
if (tmpl_getkey(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX, true)) {
|
||||||
|
cp.add_part(ChatParts::S, suffix);
|
||||||
|
}
|
||||||
|
if (tmpl_getkey(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END, true)) {
|
||||||
|
cp.add_part(ChatParts::S, end);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cp.add_part(ChatParts::S, suffix);
|
||||||
|
cp.add_part(ChatParts::S, end);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cp.add_part(ChatParts::S, suffix);
|
||||||
|
cp.add_part(ChatParts::S, end);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (alertAssistantAtEnd) {
|
||||||
|
auto assistantBeginPrefix = tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX});
|
||||||
|
cp.add_part(ChatParts::S, assistantBeginPrefix);
|
||||||
|
}
|
||||||
|
if (applyGlobalIfAny) {
|
||||||
|
auto globalEnd = tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END});
|
||||||
|
cp.add_part(ChatParts::S, globalEnd);
|
||||||
|
}
|
||||||
|
cp.dump();
|
||||||
|
tagged = cp.str();
|
||||||
|
LOGLN("DBUG:CT:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str());
|
||||||
|
LOGLN("DBUG:CT:%s:%s:CntSys[%d]:CntUsr[%d]:CntOthers[%d]", __func__, tmpl.c_str(), cntSystem, cntUser, cntOthers);
|
||||||
|
types = cp.get_partstypes();
|
||||||
|
lens = cp.get_partslens();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -516,17 +618,10 @@ inline bool chaton_tmpl_getkey_bool(const std::string &tmpl, const std::string &
|
||||||
|
|
||||||
|
|
||||||
// Given the template standard and a bunch of messages including their roles, this returns
|
// Given the template standard and a bunch of messages including their roles, this returns
|
||||||
// tagged messages, types string and lens vector. Returned types string and lens vector help
|
// the tagged messages as a string.
|
||||||
// identify the parts of the tagged msgs string, which relate to passed msgs and added tags.
|
// global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end
|
||||||
//
|
//
|
||||||
// * a string containing the tagged messages
|
// Additionally also return info about the parts that make up the tagged message.
|
||||||
// * global-begin + 1 or more [[role-begin] + [role-prefix] + msg + [role-suffix] +[role-end]] + global-end
|
|
||||||
// * a string where the chars contain info about
|
|
||||||
// type of sub-strings/parts that make up the tagged messages string.
|
|
||||||
// * a vector of ints, which give the length of each part in the tagged messages string.
|
|
||||||
//
|
|
||||||
// if a combination of system-user messages is passed, then tags between the system
|
|
||||||
// and the 1st user message, is based on the flags set wrt the corresponding template standard.
|
|
||||||
inline bool chaton_tmpl_apply_ex(
|
inline bool chaton_tmpl_apply_ex(
|
||||||
const std::string &tmpl,
|
const std::string &tmpl,
|
||||||
const std::vector<const llama_chat_message *> &msgs,
|
const std::vector<const llama_chat_message *> &msgs,
|
||||||
|
@ -538,79 +633,7 @@ inline bool chaton_tmpl_apply_ex(
|
||||||
int curSystemMsgCnt = 0,
|
int curSystemMsgCnt = 0,
|
||||||
int curUserMsgCnt = 0
|
int curUserMsgCnt = 0
|
||||||
) {
|
) {
|
||||||
if (!chaton_tmpl_exists(tmpl)) {
|
return gCT.chaton_tmpl_apply_ex(tmpl, msgs, alertAssistantAtEnd, applyGlobalIfAny, tagged, types, lens, curSystemMsgCnt, curUserMsgCnt);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
ChatParts cp = {};
|
|
||||||
if (applyGlobalIfAny) {
|
|
||||||
std::string globalBegin = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_BEGIN});
|
|
||||||
cp.add_part(ChatParts::S, globalBegin);
|
|
||||||
}
|
|
||||||
int cntSystem = curSystemMsgCnt;
|
|
||||||
int cntUser = curUserMsgCnt;
|
|
||||||
int cntOthers = 0;
|
|
||||||
for(const auto msg: msgs) {
|
|
||||||
auto role = msg->role;
|
|
||||||
auto content = msg->content;
|
|
||||||
std::string begin = chaton_tmpl_role_getkeys(tmpl, role, {K_BEGIN});
|
|
||||||
auto prefix = chaton_tmpl_role_getkeys(tmpl, role, {K_PREFIX});
|
|
||||||
auto suffix = chaton_tmpl_role_getkeys(tmpl, role, {K_SUFFIX});
|
|
||||||
auto end = chaton_tmpl_role_getkeys(tmpl, role, {K_END});
|
|
||||||
if (role == K_SYSTEM) {
|
|
||||||
cntSystem += 1;
|
|
||||||
cp.add_part(ChatParts::S, begin);
|
|
||||||
cp.add_part(ChatParts::S, prefix);
|
|
||||||
} else if (role == K_USER) {
|
|
||||||
cntUser += 1;
|
|
||||||
if ((cntSystem == 1) && (cntUser == 1)) {
|
|
||||||
if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_BEGIN)) {
|
|
||||||
cp.add_part(ChatParts::S, begin);
|
|
||||||
}
|
|
||||||
if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_1ST_USER_HAS_PREFIX)) {
|
|
||||||
cp.add_part(ChatParts::S, prefix);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cp.add_part(ChatParts::S, begin);
|
|
||||||
cp.add_part(ChatParts::S, prefix);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cntOthers += 1;
|
|
||||||
cp.add_part(ChatParts::S, begin);
|
|
||||||
cp.add_part(ChatParts::S, prefix);
|
|
||||||
}
|
|
||||||
cp.add_part(ChatParts::N, content);
|
|
||||||
if (role == K_SYSTEM) {
|
|
||||||
if (cntSystem == 1) {
|
|
||||||
if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_SUFFIX)) {
|
|
||||||
cp.add_part(ChatParts::S, suffix);
|
|
||||||
}
|
|
||||||
if (chaton_tmpl_getkey_bool(tmpl, K_SYSTEMUSER_SYSTEM_HAS_END)) {
|
|
||||||
cp.add_part(ChatParts::S, end);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cp.add_part(ChatParts::S, suffix);
|
|
||||||
cp.add_part(ChatParts::S, end);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cp.add_part(ChatParts::S, suffix);
|
|
||||||
cp.add_part(ChatParts::S, end);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (alertAssistantAtEnd) {
|
|
||||||
auto assistantBeginPrefix = chaton_tmpl_role_getkeys(tmpl, K_ASSISTANT, {K_BEGIN, K_PREFIX});
|
|
||||||
cp.add_part(ChatParts::S, assistantBeginPrefix);
|
|
||||||
}
|
|
||||||
if (applyGlobalIfAny) {
|
|
||||||
auto globalEnd = chaton_tmpl_role_getkeys(tmpl, K_GLOBAL, {K_END});
|
|
||||||
cp.add_part(ChatParts::S, globalEnd);
|
|
||||||
}
|
|
||||||
cp.dump();
|
|
||||||
tagged = cp.str();
|
|
||||||
LOGLN("DBUG:%s:%s:%s", __func__, tmpl.c_str(), tagged.c_str());
|
|
||||||
LOGLN("DBUG:%s:%s:CntSys[%d]:CntUsr[%d]:CntOthers[%d]", __func__, tmpl.c_str(), cntSystem, cntUser, cntOthers);
|
|
||||||
types = cp.get_partstypes();
|
|
||||||
lens = cp.get_partslens();
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given the template standard and a bunch of messages including their roles, this returns
|
// Given the template standard and a bunch of messages including their roles, this returns
|
||||||
|
@ -751,7 +774,7 @@ inline int32_t chaton_tmpl_apply_ex_capi(
|
||||||
return taggedLength;
|
return taggedLength;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copied from common.cpp
|
// Copied from common.cpp, updated wrt model and logging flow.
|
||||||
inline std::vector<llama_token> chaton_llama_tokenize(
|
inline std::vector<llama_token> chaton_llama_tokenize(
|
||||||
const struct llama_model * model,
|
const struct llama_model * model,
|
||||||
const std::string & text,
|
const std::string & text,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue