ChatON: Initial go at chat-template-apply c-api with parts info

This commit is contained in:
HanishKVC 2024-05-07 11:08:47 +05:30
parent f6a86cd209
commit 04b4a15177

View file

@ -541,6 +541,7 @@ inline int32_t chaton_tmpl_apply(
// If the passed char array is smaller than that required for the tagged messages string,
// * part of the tagged messages string which fits within dest buffer is copied
// * the returned value, indicates the size of the actual tagged message
//
// NOTE:
// * ideally the passed char array should be able to fit the tagged messages string + 0|null char.
// * if the return value from this function is larger than or equal to destLength,
@ -572,6 +573,58 @@ inline int32_t chaton_tmpl_apply_capi(
return taggedLength;
}
//
// In addition to the semantic provided by chaton_tmpl_apply_capi
// this additionally also returns info about the parts that make up
// the returned tagged message.
//
// partTypes and partLengths should be arrays that can accomodate the
// same number of elements belonging to its respective type.
// Inturn the pNumParts should point to a int which specifies the
// number of elements.
// If the generated tagged message has more parts than the specified
// *pNumParts, then the logic copies partTypes and partLengths to the
// specified length/NumOfParts only. Parallely it updates *pNumParts
// to the actual needed length (not including any terminating null char or so).
//
inline int32_t chaton_tmpl_apply_ex_capi(
const char *tmpl,
const struct llama_chat_message *msgs,
const size_t numMsgs,
bool alertAssistantAtEnd,
char *dest,
int32_t destLength,
char *partTypes,
int32_t *partLengths,
int32_t *pNumParts
) {
if ((tmpl == nullptr) || (dest == nullptr)) {
return -1;
}
std::vector<const llama_chat_message *> vMsgs;
for(size_t i=0; i<numMsgs; i++) {
vMsgs.push_back(&msgs[i]);
}
std::string taggedMsgs;
std::string types;
std::vector<int> lens;
int32_t taggedLength = chaton_tmpl_apply_ex(tmpl, vMsgs, taggedMsgs, types, lens, alertAssistantAtEnd);
if (taggedLength <= 0) {
return taggedLength;
}
if (destLength > 0) {
strlcpy(dest, taggedMsgs.c_str(), destLength);
}
if (*pNumParts > 0) {
strlcpy(partTypes, types.c_str(), *pNumParts);
for(int i=0; i < *pNumParts; i++) {
partLengths[i] = lens[i];
}
}
*pNumParts = types.length();
return taggedLength;
}
/**
* if tmpl is