diff --git a/common/chaton.hpp b/common/chaton.hpp index 6bd8250de..c16be0ee0 100644 --- a/common/chaton.hpp +++ b/common/chaton.hpp @@ -541,6 +541,7 @@ inline int32_t chaton_tmpl_apply( // If the passed char array is smaller than that required for the tagged messages string, // * part of the tagged messages string which fits within dest buffer is copied // * the returned value, indicates the size of the actual tagged message +// // NOTE: // * ideally the passed char array should be able to fit the tagged messages string + 0|null char. // * if the return value from this function is larger than or equal to destLength, @@ -572,6 +573,58 @@ inline int32_t chaton_tmpl_apply_capi( return taggedLength; } +// +// In addition to the semantic provided by chaton_tmpl_apply_capi +// this additionally also returns info about the parts that make up +// the returned tagged message. +// +// partTypes and partLengths should be arrays that can accomodate the +// same number of elements belonging to its respective type. +// Inturn the pNumParts should point to a int which specifies the +// number of elements. +// If the generated tagged message has more parts than the specified +// *pNumParts, then the logic copies partTypes and partLengths to the +// specified length/NumOfParts only. Parallely it updates *pNumParts +// to the actual needed length (not including any terminating null char or so). +// +inline int32_t chaton_tmpl_apply_ex_capi( + const char *tmpl, + const struct llama_chat_message *msgs, + const size_t numMsgs, + bool alertAssistantAtEnd, + char *dest, + int32_t destLength, + char *partTypes, + int32_t *partLengths, + int32_t *pNumParts + ) { + if ((tmpl == nullptr) || (dest == nullptr)) { + return -1; + } + std::vector vMsgs; + for(size_t i=0; i lens; + int32_t taggedLength = chaton_tmpl_apply_ex(tmpl, vMsgs, taggedMsgs, types, lens, alertAssistantAtEnd); + if (taggedLength <= 0) { + return taggedLength; + } + if (destLength > 0) { + strlcpy(dest, taggedMsgs.c_str(), destLength); + } + if (*pNumParts > 0) { + strlcpy(partTypes, types.c_str(), *pNumParts); + for(int i=0; i < *pNumParts; i++) { + partLengths[i] = lens[i]; + } + } + *pNumParts = types.length(); + return taggedLength; +} + /** * if tmpl is