From b6da7d9c9dd5b82e3e017ca2f7d9e0e031f3e1b1 Mon Sep 17 00:00:00 2001
From: HanishKVC
Date: Wed, 8 May 2024 18:29:48 +0530
Subject: [PATCH] ChatON: tokenize keeping in mind the taggedMessage subparts

Initial go
---
 common/chaton.hpp | 68 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/common/chaton.hpp b/common/chaton.hpp
index 9fa661461..c7cf5de2b 100644
--- a/common/chaton.hpp
+++ b/common/chaton.hpp
@@ -628,6 +628,74 @@ inline int32_t chaton_tmpl_apply_ex_capi(
     return taggedLength;
 }
 
+/**
+ * Tokenize text with llama_tokenize and return the tokens as a vector.
+ * Copied from common.cpp.
+ *
+ * @param model          model handed to llama_tokenize
+ * @param text           text to tokenize
+ * @param add_special    forwarded to llama_tokenize as is
+ * @param parse_special  forwarded to llama_tokenize as is
+ * @return the tokens produced by llama_tokenize, sized to the reported count
+ */
+inline std::vector<llama_token> chaton_llama_tokenize(
+        const struct llama_model * model,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    LOGLN("DBUG:%s:%s:special[add:%d, parse:%d]", __func__, text.c_str(), add_special, parse_special);
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        // A negative count is taken as the negated size that is needed; grow the buffer and retry once.
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+/**
+ * Tokenize the passed taggedText part by part, parsing special tokens only
+ * in those parts whose entry in partsTypes is ChatParts::S.
+ *
+ * @param ctx           context whose model is used for tokenizing
+ * @param taggedText    the parts, concatenated in order
+ * @param partsTypes    one type char per part, in the same order as partsLengths
+ * @param partsLengths  length of each part within taggedText
+ * @param addSpecial    forwarded as add_special for every part
+ * @return the tokens of all the parts, concatenated in order
+ *
+ * NOTE(review): addSpecial is applied to each part separately, not once for
+ * the whole text - confirm that this is what the callers expect.
+ */
+inline std::vector<llama_token> chaton_llama_tokenize_ex(
+        const llama_context *ctx,
+        const std::string &taggedText,
+        const std::string &partsTypes,
+        const std::vector<int32_t> &partsLengths,
+        bool addSpecial
+        ) {
+    std::vector<llama_token> tokens;
+    size_t iPart = 0;
+    size_t iStart = 0;
+    for (const auto partLen : partsLengths) {
+        // A part with no matching type char is treated as not special, rather than reading past partsTypes.
+        const char partType = (iPart < partsTypes.size()) ? partsTypes[iPart] : '\0';
+        iPart += 1;
+        const auto msgPart = taggedText.substr(iStart, partLen);
+        iStart += partLen;
+        const bool parseSpecial = (partType == ChatParts::S);
+        const auto curTokens = chaton_llama_tokenize(llama_get_model(ctx), msgPart, addSpecial, parseSpecial);
+        tokens.insert(tokens.end(), curTokens.begin(), curTokens.end());
+    }
+    return tokens;
+}
+
 /**
  * if tmpl is